mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Working on feature extraction, interfaces refined, a number of models working, some in progress.
This commit is contained in:
parent
24e7535278
commit
d23a2697d0
timm/data/__init__.py
@@ -7,3 +7,4 @@ from .transforms_factory import create_transform
 from .mixup import mixup_batch, FastCollateMixup
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
+from .real_labels import RealLabelsImagenet
timm/data/real_labels.py (new file, 36 lines)
@@ -0,0 +1,36 @@
import os
import json
import numpy as np


class RealLabelsImagenet:

    def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
        with open(real_json) as real_labels:
            real_labels = json.load(real_labels)
            real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
        self.real_labels = real_labels
        self.filenames = filenames
        assert len(self.filenames) == len(self.real_labels)
        self.topk = topk
        self.is_correct = {k: [] for k in topk}
        self.sample_idx = 0

    def add_result(self, output):
        maxk = max(self.topk)
        _, pred_batch = output.topk(maxk, 1, True, True)
        pred_batch = pred_batch.cpu().numpy()
        for pred in pred_batch:
            filename = self.filenames[self.sample_idx]
            filename = os.path.basename(filename)
            if self.real_labels[filename]:
                for k in self.topk:
                    self.is_correct[k].append(
                        any([p in self.real_labels[filename] for p in pred[:k]]))
            self.sample_idx += 1

    def get_accuracy(self, k=None):
        if k is None:
            return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
        else:
            return float(np.mean(self.is_correct[k])) * 100
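A quick usage sketch for the new RealLabelsImagenet helper. The model, val_loader, and filenames objects below are illustrative stand-ins, and real.json is assumed to hold the relabeled ("ReaL") ImageNet validation labels keyed by validation image index:

    # Hypothetical validation loop; only RealLabelsImagenet itself comes from this commit.
    real_labels = RealLabelsImagenet(filenames, real_json='real.json', topk=(1, 5))
    for input, _ in val_loader:
        output = model(input)             # class logits, shape (batch_size, 1000)
        real_labels.add_result(output)    # records per-sample top-k correctness
    print(real_labels.get_accuracy(k=1))  # ReaL top-1 accuracy, in percent
    print(real_labels.get_accuracy(k=5))  # ReaL top-5 accuracy, in percent

Samples whose ReaL label list is empty are skipped by add_result, so accuracy is computed only over images that kept at least one valid label.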
timm/models/densenet.py
@@ -13,6 +13,7 @@ import torch.utils.checkpoint as cp
 from torch.jit.annotations import List

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .features import FeatureNet
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
 from .registry import register_model
@@ -199,6 +200,9 @@ class DenseNet(nn.Module):
             ('norm0', norm_layer(num_init_features)),
             ('pool0', stem_pool),
         ]))
+        self.feature_info = [
+            dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
+        current_stride = 4

         # DenseBlocks
         num_features = num_init_features
@@ -212,21 +216,27 @@ class DenseNet(nn.Module):
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient
             )
-            self.features.add_module('denseblock%d' % (i + 1), block)
+            module_name = f'denseblock{(i + 1)}'
+            self.features.add_module(module_name, block)
             num_features = num_features + num_layers * growth_rate
             transition_aa_layer = None if aa_stem_only else aa_layer
             if i != len(block_config) - 1:
+                self.feature_info += [
+                    dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
+                current_stride *= 2
                 trans = DenseTransition(
                     num_input_features=num_features, num_output_features=num_features // 2,
                     norm_layer=norm_layer, aa_layer=transition_aa_layer)
-                self.features.add_module('transition%d' % (i + 1), trans)
+                self.features.add_module(f'transition{i + 1}', trans)
                 num_features = num_features // 2

         # Final batch norm
         self.features.add_module('norm5', norm_layer(num_features))

-        # Linear layer
+        self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
         self.num_features = num_features

+        # Linear layer
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

@@ -279,16 +289,14 @@ def _filter_torchvision_pretrained(state_dict):


 def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
+    features = False
+    out_indices = None
     if kwargs.pop('features_only', False):
-        assert False, 'Not Implemented'  # TODO
-        load_strict = False
+        features = True
         kwargs.pop('num_classes', 0)
-        model_class = DenseNet
-    else:
-        load_strict = True
-        model_class = DenseNet
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     default_cfg = default_cfgs[variant]
-    model = model_class(growth_rate=growth_rate, block_config=block_config, **kwargs)
+    model = DenseNet(growth_rate=growth_rate, block_config=block_config, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(
@@ -296,7 +304,9 @@ def _densenet(variant, growth_rate, block_config, pretrained, **kwargs):
             num_classes=kwargs.get('num_classes', 0),
             in_chans=kwargs.get('in_chans', 3),
             filter_fn=_filter_torchvision_pretrained,
-            strict=load_strict)
+            strict=not features)
+    if features:
+        model = FeatureNet(model, out_indices, flatten_sequential=True)
     return model

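With the assert removed, the features_only path in _densenet is wired up end to end. A rough sketch of how it is meant to be exercised (the densenet121 entrypoint and tensor shapes are assumptions based on the existing registry, not part of this diff):

    import torch
    from timm.models.densenet import densenet121

    # features_only is popped inside _densenet; the classifier model is then wrapped in FeatureNet.
    model = densenet121(pretrained=False, features_only=True, out_indices=(0, 1, 2, 3, 4))
    feats = model(torch.randn(1, 3, 224, 224))
    print(model.feature_info.channels())      # channel counts for the selected indices
    print([tuple(f.shape) for f in feats])    # one feature map per index, reductions 2 through 32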
timm/models/efficientnet.py
@@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .feature_hooks import FeatureHooks
+from .features import FeatureInfo
 from .helpers import load_pretrained, adapt_model_from_file
 from .layers import SelectAdaptivePool2d, create_conv2d
 from .registry import register_model
@@ -438,42 +439,22 @@ class EfficientNetFeatures(nn.Module):
             channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
-        self._feature_info = builder.features  # builder provides info about feature channels for each block
+        self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_to_feature_idx = {
-            v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
+            v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
         self._in_chs = builder.in_chs

         efficientnet_init_weights(self)
         if _DEBUG:
-            for k, v in self._feature_info.items():
-                print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
+            for fi, v in enumerate(self.feature_info):
+                print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
         self.feature_hooks = None
         if feature_location != 'bottleneck':
-            hooks = [dict(
-                name=self._feature_info[idx]['module'],
-                type=self._feature_info[idx]['hook_type']) for idx in out_indices]
+            hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())

-    def feature_channels(self, idx=None):
-        """ Feature Channel Shortcut
-        Returns feature channel count for each output index if idx == None. If idx is an integer, will
-        return feature channel count for that feature block index (independent of out_indices setting).
-        """
-        if isinstance(idx, int):
-            return self._feature_info[idx]['num_chs']
-        return [self._feature_info[i]['num_chs'] for i in self.out_indices]
-
-    def feature_info(self, idx=None):
-        """ Feature Channel Shortcut
-        Returns feature channel count for each output index if idx == None. If idx is an integer, will
-        return feature channel count for that feature block index (independent of out_indices setting).
-        """
-        if isinstance(idx, int):
-            return self._feature_info[idx]
-        return [self._feature_info[i] for i in self.out_indices]
-
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
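For context, get_by_key(keys=('module', 'hook_type')) returns one small dict per selected out index, which is exactly the shape FeatureHooks consumes after this change (module names below are illustrative):

    # Illustrative result of self.feature_info.get_by_key(keys=('module', 'hook_type')):
    hooks = [
        {'module': 'blocks.1.1', 'hook_type': 'forward'},
        {'module': 'blocks.2.2', 'hook_type': 'forward'},
        {'module': 'blocks.4.0', 'hook_type': 'forward_pre'},
    ]
    # FeatureHooks(hooks, self.named_modules()) then attaches one hook per entry.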
timm/models/efficientnet_builder.py
@@ -225,7 +225,7 @@ class EfficientNetBuilder:

         # state updated during build, consumed by model
         self.in_chs = None
-        self.features = OrderedDict()
+        self.features = []

     def _round_channels(self, chs):
         return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
@@ -291,7 +291,6 @@ class EfficientNetBuilder:
         total_block_idx = 0
         current_stride = 2
         current_dilation = 1
-        feature_idx = 0
         stages = []
         # outer list of block_args defines the stacks ('stages' by some conventions)
         for stage_idx, stage_block_args in enumerate(model_block_args):
@@ -351,13 +350,15 @@ class EfficientNetBuilder:
                 # stash feature module name and channel info for model feature extraction
                 if extract_features:
                     feature_info = block.feature_info(extract_features)
-                    if feature_info['module']:
-                        feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
+                    module_name = f'blocks.{stage_idx}.{block_idx}'
+                    if 'module' in feature_info and feature_info['module']:
+                        feature_info['module'] = '.'.join([module_name, feature_info['module']])
+                    else:
+                        feature_info['module'] = module_name
                     feature_info['stage_idx'] = stage_idx
                     feature_info['block_idx'] = block_idx
                     feature_info['reduction'] = current_stride
-                    self.features[feature_idx] = feature_info
-                    feature_idx += 1
+                    self.features.append(feature_info)

                 total_block_idx += 1  # incr global block idx (across all stacks)
             stages.append(nn.Sequential(*blocks))
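Each dict appended to self.features now carries everything the downstream consumers read; a representative entry looks roughly like this (values are illustrative):

    # Illustrative builder feature-info entry; keys match what FeatureInfo and FeatureHooks expect.
    feature_info = {
        'module': 'blocks.2.0',    # fully qualified module name for taps/hooks
        'hook_type': 'forward',    # how FeatureHooks should attach, when hooks are used
        'num_chs': 40,             # channels of the feature map at this point
        'stage_idx': 2,
        'block_idx': 0,
        'reduction': 8,            # input resolution reduction factor
    }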
timm/models/feature_hooks.py
@@ -1,3 +1,9 @@
+""" PyTorch Feature Hook Helper
+
+This class helps gather features from a network via hooks specified on the module name.
+
+Hacked together by Ross Wightman
+"""
 import torch

 from collections import defaultdict, OrderedDict
@@ -7,20 +13,21 @@ from typing import List

 class FeatureHooks:

-    def __init__(self, hooks, named_modules):
+    def __init__(self, hooks, named_modules, output_as_dict=False):
         # setup feature hooks
         modules = {k: v for k, v in named_modules}
         for h in hooks:
-            hook_name = h['name']
+            hook_name = h['module']
             m = modules[hook_name]
             hook_fn = partial(self._collect_output_hook, hook_name)
-            if h['type'] == 'forward_pre':
+            if h['hook_type'] == 'forward_pre':
                 m.register_forward_pre_hook(hook_fn)
-            elif h['type'] == 'forward':
+            elif h['hook_type'] == 'forward':
                 m.register_forward_hook(hook_fn)
             else:
                 assert False, "Unsupported hook type"
         self._feature_outputs = defaultdict(OrderedDict)
+        self.output_as_dict = output_as_dict

     def _collect_output_hook(self, name, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -29,6 +36,9 @@ class FeatureHooks:
         self._feature_outputs[x.device][name] = x

     def get_output(self, device) -> List[torch.tensor]:
-        output = list(self._feature_outputs[device].values())
+        if self.output_as_dict:
+            output = self._feature_outputs[device]
+        else:
+            output = list(self._feature_outputs[device].values())
         self._feature_outputs[device] = OrderedDict()  # clear after reading
         return output
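A small self-contained sketch of the hook flow (the toy network and module names are made up; only FeatureHooks comes from this file):

    import torch
    import torch.nn as nn
    from timm.models.feature_hooks import FeatureHooks   # assumed import path

    net = nn.Sequential(nn.Conv2d(3, 8, 3, 2, 1), nn.ReLU(), nn.Conv2d(8, 16, 3, 2, 1))
    hooks = [
        {'module': '0', 'hook_type': 'forward'},   # capture output of the first conv
        {'module': '2', 'hook_type': 'forward'},   # capture output of the second conv
    ]
    fh = FeatureHooks(hooks, net.named_modules(), output_as_dict=True)

    x = torch.randn(1, 3, 32, 32)
    net(x)                           # hooks fire during the forward pass
    feats = fh.get_output(x.device)  # OrderedDict {'0': tensor, '2': tensor}, cleared after reading
    print({k: tuple(v.shape) for k, v in feats.items()})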
timm/models/features.py (new file, 251 lines)
@@ -0,0 +1,251 @@
""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

Hacked together by Ross Wightman
"""
from collections import OrderedDict
from typing import Dict, List, Tuple, Any
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self._out_indices = out_indices
        self._info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self._info), out_indices)

    def channels(self, idx=None):
        """ feature channels accessor
        if idx == None, returns feature channel count at each output index
        if idx is an integer, return feature channel count for that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['num_chs']
        return [self._info[i]['num_chs'] for i in self._out_indices]

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        if idx == None, returns feature reduction factor at each output index
        if idx is an integer, return feature channel count at that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['reduction']
        return [self._info[i]['reduction'] for i in self._out_indices]

    def module_name(self, idx=None):
        """ feature module name accessor
        if idx == None, returns feature module name at each output index
        if idx is an integer, return feature module name at that feature module index
        """
        if isinstance(idx, int):
            return self._info[idx]['module']
        return [self._info[i]['module'] for i in self._out_indices]

    def get_by_key(self, idx=None, keys=None):
        """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
        """
        if isinstance(idx, int):
            return self._info[idx] if keys is None else {k: self._info[idx][k] for k in keys}
        if keys is None:
            return [self._info[i] for i in self._out_indices]
        else:
            return [{k: self._info[i][k] for k in keys} for i in self._out_indices]

    def __getitem__(self, item):
        return self._info[item]

    def __len__(self):
        return len(self._info)


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                ml.append(('_'.join([name, child_name]), child_module))
        else:
            ml.append((name, module))
    return ml


def _check_return_layers(input_return_layers, modules):
    return_layers = {}
    for k, v in input_return_layers.items():
        ks = k.split('.')
        assert 0 < len(ks) <= 2
        return_layers['_'.join(ks)] = v
    return_set = set(return_layers.keys())
    sdiff = return_set - {name for name, _ in modules}
    if sdiff:
        raise ValueError(f'return_layers {sdiff} are not present in model')
    return return_layers, return_set


class LayerGetterDict(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model as a dictionary

    Originally based on IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
    long as `features` is a sequential container assigned to the model).

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model

    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        self.return_layers, remaining = _check_return_layers(return_layers, modules)
        layers = OrderedDict()
        self.concat = concat
        for name, module in modules:
            layers[name] = module
            if name in remaining:
                remaining.remove(name)
            if not remaining:
                break
        super(LayerGetterDict, self).__init__(layers)

    def forward(self, x) -> Dict[Any, torch.Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out


class LayerGetterList(nn.Sequential):
    """
    Module wrapper that returns intermediate layers from a model as a list

    Originally based on IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
    long as `features` is a sequential container assigned to the model and flatten_sequential=True.

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model

    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        super(LayerGetterList, self).__init__()
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        self.return_layers, remaining = _check_return_layers(return_layers, modules)
        self.concat = concat
        for name, module in modules:
            self.add_module(name, module)
            if name in remaining:
                remaining.remove(name)
            if not remaining:
                break

    def forward(self, x) -> List[torch.Tensor]:
        out = []
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out.append(torch.cat(x, 1) if self.concat else x[0])
                else:
                    out.append(x)
        return out


def _resolve_feature_info(net, out_indices, feature_info=None):
    if feature_info is None:
        feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


class FeatureNet(nn.Module):
    """ FeatureNet

    Wrap a model and extract features as specified by the out indices, the network
    is partially re-built from contained modules using the LayerGetters.

    Please read the docstrings of the LayerGetter classes, they will not work on all models.
    """
    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False,
            feature_info=None, feature_concat=False, flatten_sequential=False):
        super(FeatureNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        module_names = self.feature_info.module_name()
        return_layers = {}
        for i in range(len(out_indices)):
            return_layers[module_names[i]] = out_map[i] if out_map is not None else out_indices[i]
        lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
        self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)

    def forward(self, x):
        output = self.body(x)
        return output
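A rough end-to-end sketch of FeatureNet on a model that exposes a feature_info list (the DenseNet constructor arguments mirror the densenet121 config in this repo; treat them as illustrative):

    import torch
    from timm.models.densenet import DenseNet
    from timm.models.features import FeatureNet

    # out_map renames the taps, out_as_dict returns an OrderedDict instead of a list, and
    # flatten_sequential lets the layer getter reach modules one Sequential deep
    # ('features.denseblock1' becomes 'features_denseblock1' internally).
    net = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16))
    extractor = FeatureNet(
        net, out_indices=(1, 2, 3, 4), out_map=('c2', 'c3', 'c4', 'c5'),
        out_as_dict=True, flatten_sequential=True)

    out = extractor(torch.randn(1, 3, 224, 224))
    print(list(out.keys()))                   # ['c2', 'c3', 'c4', 'c5']
    print(extractor.feature_info.channels())  # channel counts at the selected indices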
timm/models/gluon_xception.py
@@ -13,16 +13,16 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d
+from .layers import SelectAdaptivePool2d, get_padding
 from .registry import register_model

-__all__ = ['Xception65', 'Xception71']
+__all__ = ['Xception65']

 default_cfgs = {
     'gluon_xception65': {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
         'input_size': (3, 299, 299),
-        'crop_pct': 0.875,
+        'crop_pct': 0.903,
         'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
@@ -32,52 +32,13 @@ default_cfgs = {
         'classifier': 'fc'
         # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
     },
-    'gluon_xception71': {
-        'url': '',
-        'input_size': (3, 299, 299),
-        'crop_pct': 0.875,
-        'pool_size': (5, 5),
-        'interpolation': 'bicubic',
-        'mean': IMAGENET_DEFAULT_MEAN,
-        'std': IMAGENET_DEFAULT_STD,
-        'num_classes': 1000,
-        'first_conv': 'conv1',
-        'classifier': 'fc'
-        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
-    }
 }

 """ PADDING NOTES
 The original PyTorch and Gluon impl of these models dutifully reproduced the
 aligned padding added to Tensorflow models for Deeplab. This padding was compensating
 for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
-
-So, I'm phasing out the 'fixed_padding' ported from TF and replacing with normal
-PyTorch padding, some asserts to validate the equivalence for any scenario we'd
-care about before removing altogether.
 """
-_USE_FIXED_PAD = False
-
-
-def _pytorch_padding(kernel_size, stride=1, dilation=1, **_):
-    if _USE_FIXED_PAD:
-        return 0  # FIXME remove once verified
-    else:
-        padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
-
-        # FIXME remove once verified
-        fp = _fixed_padding(kernel_size, dilation)
-        assert all(padding == p for p in fp)
-
-        return padding
-
-
-def _fixed_padding(kernel_size, dilation):
-    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
-    pad_total = kernel_size_effective - 1
-    pad_beg = pad_total // 2
-    pad_end = pad_total - pad_beg
-    return [pad_beg, pad_end, pad_beg, pad_end]


 class SeparableConv2d(nn.Module):
@@ -88,24 +49,16 @@ class SeparableConv2d(nn.Module):
         self.kernel_size = kernel_size
         self.dilation = dilation

-        padding = _fixed_padding(self.kernel_size, self.dilation)
-        if _USE_FIXED_PAD and any(p > 0 for p in padding):
-            self.fixed_padding = nn.ZeroPad2d(padding)
-        else:
-            self.fixed_padding = None
-
         # depthwise convolution
+        padding = get_padding(kernel_size, stride, dilation)
         self.conv_dw = nn.Conv2d(
             inplanes, inplanes, kernel_size, stride=stride,
-            padding=_pytorch_padding(kernel_size, stride, dilation), dilation=dilation, groups=inplanes, bias=bias)
+            padding=padding, dilation=dilation, groups=inplanes, bias=bias)
         self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
         # pointwise convolution
         self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)

     def forward(self, x):
-        if self.fixed_padding is not None:
-            # FIXME remove once verified
-            x = self.fixed_padding(x)
         x = self.conv_dw(x)
         x = self.bn(x)
         x = self.conv_pw(x)
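The padding note can be sanity checked numerically. The helpers below restate the removed _pytorch_padding/_fixed_padding math so the equivalence is easy to verify; get_padding from timm.models.layers is assumed to compute the same symmetric value:

    # Symmetric (PyTorch-style) padding vs the TF-style begin/end padding the old code applied.
    def symmetric_padding(kernel_size, stride=1, dilation=1):
        return ((stride - 1) + dilation * (kernel_size - 1)) // 2

    def fixed_padding(kernel_size, dilation):
        effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = effective - 1
        return [pad_total // 2, pad_total - pad_total // 2] * 2

    for k, d in [(1, 1), (3, 1), (3, 2), (3, 4)]:
        p = symmetric_padding(k, 1, d)
        assert all(p == fp for fp in fixed_padding(k, d)), (k, d)
    # For the odd kernels and stride-1 convs used here, both schemes pad identically,
    # which is why the explicit fixed_padding path could be dropped.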
@@ -113,58 +66,37 @@ class SeparableConv2d(nn.Module):


 class Block(nn.Module):
-    def __init__(self, inplanes, planes, num_reps, stride=1, dilation=1, norm_layer=None,
-                 norm_kwargs=None, start_with_relu=True, grow_first=True, is_last=False):
+    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
+                 norm_layer=None, norm_kwargs=None, ):
         super(Block, self).__init__()
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
-        if planes != inplanes or stride != 1:
+        if isinstance(planes, (list, tuple)):
+            assert len(planes) == 3
+        else:
+            planes = (planes,) * 3
+        outplanes = planes[-1]
+
+        if outplanes != inplanes or stride != 1:
             self.skip = nn.Sequential()
             self.skip.add_module('conv1', nn.Conv2d(
-                inplanes, planes, 1, stride=stride, bias=False)),
-            self.skip.add_module('bn1', norm_layer(num_features=planes, **norm_kwargs))
+                inplanes, outplanes, 1, stride=stride, bias=False)),
+            self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
         else:
             self.skip = None

         rep = OrderedDict()
-        l = 1
-        filters = inplanes
-        if grow_first:
-            if start_with_relu:
-                rep['act%d' % l] = nn.ReLU(inplace=False)  # NOTE: silent failure if inplace=True here
-            rep['conv%d' % l] = SeparableConv2d(
-                inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
-            filters = planes
-            l += 1
-
-        for _ in range(num_reps - 1):
-            if grow_first or start_with_relu:
-                # FIXME being conservative with inplace here, think it's fine to leave True?
-                rep['act%d' % l] = nn.ReLU(inplace=grow_first or not start_with_relu)
-            rep['conv%d' % l] = SeparableConv2d(
-                filters, filters, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % l] = norm_layer(num_features=filters, **norm_kwargs)
-            l += 1
-
-        if not grow_first:
-            rep['act%d' % l] = nn.ReLU(inplace=True)
-            rep['conv%d' % l] = SeparableConv2d(
-                inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
-            l += 1
-
-        if stride != 1:
-            rep['act%d' % l] = nn.ReLU(inplace=True)
-            rep['conv%d' % l] = SeparableConv2d(
-                planes, planes, 3, stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
-            l += 1
-        elif is_last:
-            rep['act%d' % l] = nn.ReLU(inplace=True)
-            rep['conv%d' % l] = SeparableConv2d(
-                planes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
-            l += 1
+        for i in range(3):
+            rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
+            rep['conv%d' % (i + 1)] = SeparableConv2d(
+                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
+                norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
+            inplanes = planes[i]
+
+        if not start_with_relu:
+            del rep['act1']
+        else:
+            rep['act1'] = nn.ReLU(inplace=False)
         self.rep = nn.Sequential(rep)

     def forward(self, x):
@@ -176,7 +108,10 @@ class Block(nn.Module):


 class Xception65(nn.Module):
-    """Modified Aligned Xception
+    """Modified Aligned Xception.
+
+    NOTE: only the 65 layer version is included here, the 71 layer variant
+    was not correct and had no pretrained weights
     """

     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
@@ -212,25 +147,21 @@ class Xception65(nn.Module):
         self.bn2 = norm_layer(num_features=64)

         self.block1 = Block(
-            64, 128, num_reps=2, stride=2,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False)
+            64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.block2 = Block(
-            128, 256, num_reps=2, stride=2,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)
+            128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.block3 = Block(
-            256, 728, num_reps=2, stride=entry_block3_stride,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
+            256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)

         # Middle flow
         self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
-            728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
-            for i in range(4, 20)]))
+            728, 728, stride=1, dilation=middle_block_dilation,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))

         # Exit flow
         self.block20 = Block(
-            728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
+            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)

         self.conv3 = SeparableConv2d(
             1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
@@ -305,147 +236,6 @@ class Xception65(nn.Module):
         return x


-class Xception71(nn.Module):
-    """Modified Aligned Xception
-    """
-
-    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
-                 norm_kwargs=None, drop_rate=0., global_pool='avg'):
-        super(Xception71, self).__init__()
-        self.num_classes = num_classes
-        self.drop_rate = drop_rate
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
-        if output_stride == 32:
-            entry_block3_stride = 2
-            exit_block20_stride = 2
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 1)
-        elif output_stride == 16:
-            entry_block3_stride = 2
-            exit_block20_stride = 1
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 2)
-        elif output_stride == 8:
-            entry_block3_stride = 1
-            exit_block20_stride = 1
-            middle_block_dilation = 2
-            exit_block_dilations = (2, 4)
-        else:
-            raise NotImplementedError
-
-        # Entry flow
-        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
-        self.bn1 = norm_layer(num_features=32, **norm_kwargs)
-        self.relu = nn.ReLU(inplace=True)
-
-        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn2 = norm_layer(num_features=64)
-
-        self.block1 = Block(
-            64, 128, num_reps=2, stride=2, norm_layer=norm_layer,
-            norm_kwargs=norm_kwargs, start_with_relu=False)
-        self.block2 = nn.Sequential(*[
-            Block(
-                128, 256, num_reps=2, stride=1, norm_layer=norm_layer,
-                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
-            Block(
-                256, 256, num_reps=2, stride=2, norm_layer=norm_layer,
-                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
-            Block(
-                256, 728, num_reps=2, stride=2, norm_layer=norm_layer,
-                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)])
-        self.block3 = Block(
-            728, 728, num_reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
-            norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
-
-        # Middle flow
-        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
-            728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
-            for i in range(4, 20)]))
-
-        # Exit flow
-        self.block20 = Block(
-            728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
-
-        self.conv3 = SeparableConv2d(
-            1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
-
-        self.conv4 = SeparableConv2d(
-            1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
-
-        self.num_features = 2048
-        self.conv5 = SeparableConv2d(
-            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
-
-    def get_classifier(self):
-        return self.fc
-
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.num_classes = num_classes
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        if num_classes:
-            num_features = self.num_features * self.global_pool.feat_mult()
-            self.fc = nn.Linear(num_features, num_classes)
-        else:
-            self.fc = nn.Identity()
-
-    def forward_features(self, x):
-        # Entry flow
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-
-        x = self.conv2(x)
-        x = self.bn2(x)
-        x = self.relu(x)
-
-        x = self.block1(x)
-        # add relu here
-        x = self.relu(x)
-        # low_level_feat = x
-        x = self.block2(x)
-        # c2 = x
-        x = self.block3(x)
-
-        # Middle flow
-        x = self.mid(x)
-        # c3 = x
-
-        # Exit flow
-        x = self.block20(x)
-        x = self.relu(x)
-        x = self.conv3(x)
-        x = self.bn3(x)
-        x = self.relu(x)
-
-        x = self.conv4(x)
-        x = self.bn4(x)
-        x = self.relu(x)
-
-        x = self.conv5(x)
-        x = self.bn5(x)
-        x = self.relu(x)
-        return x
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        x = self.global_pool(x).flatten(1)
-        if self.drop_rate:
-            F.dropout(x, self.drop_rate, training=self.training)
-        x = self.fc(x)
-        return x
-
-
 @register_model
 def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """ Modified Aligned Xception-65
@@ -456,15 +246,3 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
-
-
-@register_model
-def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """ Modified Aligned Xception-71
-    """
-    default_cfg = default_cfgs['gluon_xception71']
-    model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
timm/models/inception_resnet_v2.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .features import FeatureNet
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
 from .registry import register_model
@@ -231,9 +232,13 @@ class InceptionResnetV2(nn.Module):
         self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
         self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
         self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
+
         self.maxpool_3a = nn.MaxPool2d(3, stride=2)
         self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
         self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
+        self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
+
         self.maxpool_5a = nn.MaxPool2d(3, stride=2)
         self.mixed_5b = Mixed_5b()
         self.repeat = nn.Sequential(
@@ -248,6 +253,8 @@ class InceptionResnetV2(nn.Module):
             Block35(scale=0.17),
             Block35(scale=0.17)
         )
+        self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
+
         self.mixed_6a = Mixed_6a()
         self.repeat_1 = nn.Sequential(
             Block17(scale=0.10),
@@ -271,6 +278,8 @@ class InceptionResnetV2(nn.Module):
             Block17(scale=0.10),
             Block17(scale=0.10)
         )
+        self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
+
         self.mixed_7a = Mixed_7a()
         self.repeat_2 = nn.Sequential(
             Block8(scale=0.20),
@@ -285,6 +294,8 @@ class InceptionResnetV2(nn.Module):
         )
         self.block8 = Block8(no_relu=True)
         self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
+        self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
+
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
         self.classif = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@@ -328,30 +339,34 @@ class InceptionResnetV2(nn.Module):
         return x


-@register_model
-def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    r"""InceptionResnetV2 model architecture from the
-    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
-    """
-    default_cfg = default_cfgs['inception_resnet_v2']
-    model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
+def _inception_resnet_v2(variant, pretrained=False, **kwargs):
+    load_strict, features, out_indices = True, False, None
+    if kwargs.pop('features_only', False):
+        load_strict, features, out_indices = False, True, kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+        kwargs.pop('num_classes', 0)
+    model = InceptionResnetV2(**kwargs)
+    model.default_cfg = default_cfgs[variant]
     if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
+        load_pretrained(
+            model,
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+    if features:
+        model = FeatureNet(model, out_indices)
     return model


 @register_model
-def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def inception_resnet_v2(pretrained=False, **kwargs):
+    r"""InceptionResnetV2 model architecture from the
+    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
+    """
+    return _inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
     r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
     As per https://arxiv.org/abs/1705.07204 and
     https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
     """
-    default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
-    model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-
-    return model
+    return _inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
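The registered entrypoints now route through the shared _inception_resnet_v2 factory, so the features path can be exercised the same way as the other converted models. A rough sketch (shapes in the comment follow the feature_info entries added above; the call itself is an assumption about how the kwargs are forwarded):

    import torch
    from timm.models.inception_resnet_v2 import inception_resnet_v2

    model = inception_resnet_v2(pretrained=False, features_only=True, out_indices=(0, 1, 2, 3, 4))
    feats = model(torch.randn(1, 3, 299, 299))
    print([tuple(f.shape) for f in feats])   # five maps at reductions 2, 4, 8, 16, 32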
@ -17,6 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
from .feature_hooks import FeatureHooks
|
from .feature_hooks import FeatureHooks
|
||||||
|
from .features import FeatureInfo
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
|
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
@@ -182,22 +183,20 @@ class MobileNetV3Features(nn.Module):
             channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
-        self._feature_info = builder.features  # builder provides info about feature channels for each block
+        self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_to_feature_idx = {
-            v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
+            v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
         self._in_chs = builder.in_chs

         efficientnet_init_weights(self)
         if _DEBUG:
-            for k, v in self._feature_info.items():
-                print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
+            for fi, v in enumerate(self.feature_info):
+                print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
         self.feature_hooks = None
         if feature_location != 'bottleneck':
-            hooks = [dict(
-                name=self._feature_info[idx]['module'],
-                type=self._feature_info[idx]['hook_type']) for idx in out_indices]
+            hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())

     def feature_channels(self, idx=None):
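The hook path above swaps a hand-built list of dicts for metadata pulled straight from FeatureInfo, then hands it to FeatureHooks, which captures intermediate activations with forward hooks rather than returning them from forward(). A minimal sketch of that underlying mechanism in plain PyTorch (the toy network and names are illustrative, not what MobileNetV3 actually records):

    import torch
    import torch.nn as nn

    features = {}

    def save_output(name):
        def hook(module, inputs, output):
            features[name] = output  # stash the activation seen during forward()
        return hook

    net = nn.Sequential(
        nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU(),
        nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU())

    # attach a hook to every module we want to tap
    for name, module in net.named_modules():
        if isinstance(module, nn.Conv2d):
            module.register_forward_hook(save_output(name))

    _ = net(torch.randn(1, 3, 32, 32))
    print({k: tuple(v.shape) for k, v in features.items()})  # {'0': (1, 8, 16, 16), '2': (1, 16, 8, 8)}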
@@ -206,17 +205,8 @@ class MobileNetV3Features(nn.Module):
         return feature channel count for that feature block index (independent of out_indices setting).
         """
         if isinstance(idx, int):
-            return self._feature_info[idx]['num_chs']
-        return [self._feature_info[i]['num_chs'] for i in self.out_indices]
-
-    def feature_info(self, idx=None):
-        """ Feature Channel Shortcut
-        Returns feature channel count for each output index if idx == None. If idx is an integer, will
-        return feature channel count for that feature block index (independent of out_indices setting).
-        """
-        if isinstance(idx, int):
-            return self._feature_info[idx]
-        return [self._feature_info[i] for i in self.out_indices]
+            return self.feature_info[idx]['num_chs']
+        return [self.feature_info[i]['num_chs'] for i in self.out_indices]

     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d
+from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
 from .registry import register_model

 __all__ = ['NASNetALarge']
@@ -24,43 +24,31 @@ default_cfgs = {
 }


-class MaxPoolPad(nn.Module):
+class ActConvBn(nn.Module):

-    def __init__(self):
-        super(MaxPoolPad, self).__init__()
-        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
-        self.pool = nn.MaxPool2d(3, stride=2, padding=1)
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
+        super(ActConvBn, self).__init__()
+        self.act = nn.ReLU()
+        self.conv = create_conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

     def forward(self, x):
-        x = self.pad(x)
-        x = self.pool(x)
-        x = x[:, :, 1:, 1:]
-        return x
-
-
-class AvgPoolPad(nn.Module):
-
-    def __init__(self, stride=2, padding=1):
-        super(AvgPoolPad, self).__init__()
-        self.pad = nn.ZeroPad2d((1, 0, 1, 0))
-        self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)
-
-    def forward(self, x):
-        x = self.pad(x)
-        x = self.pool(x)
-        x = x[:, :, 1:, 1:]
+        x = self.act(x)
+        x = self.conv(x)
+        x = self.bn(x)
         return x


 class SeparableConv2d(nn.Module):

-    def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
         super(SeparableConv2d, self).__init__()
-        self.depthwise_conv2d = nn.Conv2d(
-            in_channels, in_channels, dw_kernel,
-            stride=dw_stride, padding=dw_padding,
-            bias=bias, groups=in_channels)
-        self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
+        self.depthwise_conv2d = create_conv2d(
+            in_channels, in_channels, kernel_size=kernel_size,
+            stride=stride, padding=padding, groups=in_channels)
+        self.pointwise_conv2d = create_conv2d(
+            in_channels, out_channels, kernel_size=1, padding=0)

     def forward(self, x):
         x = self.depthwise_conv2d(x)
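Routing the convolutions through create_conv2d with a padding string is what lets these ports request TensorFlow-style 'same' padding (the original NASNet/PNASNet weights were trained with it) instead of the hard-coded symmetric padding plus ZeroPad2d/slicing workarounds above. A rough sketch of what 'same' padding computes for a strided op, assuming the usual TF definition (this is not timm's exact helper):

    import math
    import torch
    import torch.nn.functional as F

    def pad_same(x, kernel_size, stride, dilation=1):
        # pad so the output size is ceil(input / stride), as TF 'SAME' does
        def amount(size):
            eff_k = (kernel_size - 1) * dilation + 1
            out = math.ceil(size / stride)
            return max((out - 1) * stride + eff_k - size, 0)
        ph, pw = amount(x.shape[-2]), amount(x.shape[-1])
        # any odd leftover pixel goes to the bottom/right edge
        return F.pad(x, [pw // 2, pw - pw // 2, ph // 2, ph - ph // 2])

    x = torch.randn(1, 8, 11, 11)
    w = torch.randn(8, 8, 3, 3)
    y = F.conv2d(pad_same(x, 3, 2), w, stride=2)
    print(y.shape)  # torch.Size([1, 8, 6, 6]) -> ceil(11 / 2)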
@@ -70,87 +58,48 @@ class SeparableConv2d(nn.Module):

 class BranchSeparables(nn.Module):

-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
         super(BranchSeparables, self).__init__()
-        self.relu = nn.ReLU()
-        self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
-        self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
-        self.relu1 = nn.ReLU()
-        self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
-        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
+        middle_channels = out_channels if stem_cell else in_channels
+        self.act_1 = nn.ReLU()
+        self.separable_1 = SeparableConv2d(
+            in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type)
+        self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1)
+        self.act_2 = nn.ReLU(inplace=True)
+        self.separable_2 = SeparableConv2d(
+            middle_channels, out_channels, kernel_size, stride=1, padding=pad_type)
+        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

     def forward(self, x):
-        x = self.relu(x)
+        x = self.act_1(x)
         x = self.separable_1(x)
         x = self.bn_sep_1(x)
-        x = self.relu1(x)
-        x = self.separable_2(x)
-        x = self.bn_sep_2(x)
-        return x
-
-
-class BranchSeparablesStem(nn.Module):
-
-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
-        super(BranchSeparablesStem, self).__init__()
-        self.relu = nn.ReLU()
-        self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
-        self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
-        self.relu1 = nn.ReLU()
-        self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
-        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
-
-    def forward(self, x):
-        x = self.relu(x)
-        x = self.separable_1(x)
-        x = self.bn_sep_1(x)
-        x = self.relu1(x)
-        x = self.separable_2(x)
-        x = self.bn_sep_2(x)
-        return x
-
-
-class BranchSeparablesReduction(BranchSeparables):
-
-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
-        BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias)
-        self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
-
-    def forward(self, x):
-        x = self.relu(x)
-        x = self.padding(x)
-        x = self.separable_1(x)
-        x = x[:, :, 1:, 1:].contiguous()
-        x = self.bn_sep_1(x)
-        x = self.relu1(x)
+        x = self.act_2(x)
         x = self.separable_2(x)
         x = self.bn_sep_2(x)
         return x


 class CellStem0(nn.Module):
-    def __init__(self, stem_size, num_channels=42):
+    def __init__(self, stem_size, num_channels=42, pad_type=''):
         super(CellStem0, self).__init__()
         self.num_channels = num_channels
         self.stem_size = stem_size
-        self.conv_1x1 = nn.Sequential()
-        self.conv_1x1.add_module('relu', nn.ReLU())
-        self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels, 1, stride=1, bias=False))
-        self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
+        self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)

-        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2)
-        self.comb_iter_0_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
+        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
+        self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

-        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
-        self.comb_iter_1_right = BranchSeparablesStem(self.stem_size, self.num_channels, 7, 2, 3, bias=False)
+        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+        self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

-        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
-        self.comb_iter_2_right = BranchSeparablesStem(self.stem_size, self.num_channels, 5, 2, 2, bias=False)
+        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+        self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)

-        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
+        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

-        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
-        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
+        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
+        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

     def forward(self, x):
         x1 = self.conv_1x1(x)
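BranchSeparables is two ReLU -> SeparableConv2d -> BatchNorm stages, and the separable conv itself factorizes a normal convolution into a depthwise step (groups=in_channels, purely spatial) followed by a 1x1 pointwise step that mixes channels. A small standalone sketch of that factorization, independent of the classes above:

    import torch
    import torch.nn as nn

    class DepthwiseSeparable(nn.Module):
        def __init__(self, in_chs, out_chs, kernel_size=3, stride=1):
            super().__init__()
            # depthwise: one filter per input channel, no cross-channel mixing
            self.dw = nn.Conv2d(
                in_chs, in_chs, kernel_size, stride=stride,
                padding=kernel_size // 2, groups=in_chs, bias=False)
            # pointwise: 1x1 conv that recombines channels
            self.pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)

        def forward(self, x):
            return self.pw(self.dw(x))

    m = DepthwiseSeparable(32, 64, kernel_size=5, stride=2)
    print(m(torch.randn(1, 32, 56, 56)).shape)  # torch.Size([1, 64, 28, 28])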
@@ -180,51 +129,46 @@ class CellStem0(nn.Module):

 class CellStem1(nn.Module):

-    def __init__(self, stem_size, num_channels):
+    def __init__(self, stem_size, num_channels, pad_type=''):
         super(CellStem1, self).__init__()
         self.num_channels = num_channels
         self.stem_size = stem_size
-        self.conv_1x1 = nn.Sequential()
-        self.conv_1x1.add_module('relu', nn.ReLU())
-        self.conv_1x1.add_module('conv', nn.Conv2d(2 * self.num_channels, self.num_channels, 1, stride=1, bias=False))
-        self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True))
+        self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)

-        self.relu = nn.ReLU()
+        self.act = nn.ReLU()
         self.path_1 = nn.Sequential()
         self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
         self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
-        self.path_2 = nn.ModuleList()
-        self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
+        self.path_2 = nn.Sequential()
+        self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
         self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
         self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))

-        self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1, affine=True)
+        self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)

-        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
-        self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
+        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
+        self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

-        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
-        self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, 3, bias=False)
+        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+        self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

-        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
-        self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, 2, bias=False)
+        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+        self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)

-        self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
+        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

-        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, 1, bias=False)
-        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
+        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
+        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

     def forward(self, x_conv0, x_stem_0):
         x_left = self.conv_1x1(x_stem_0)

-        x_relu = self.relu(x_conv0)
+        x_relu = self.act(x_conv0)
         # path 1
         x_path1 = self.path_1(x_relu)
         # path 2
-        x_path2 = self.path_2.pad(x_relu)
-        x_path2 = x_path2[:, :, 1:, 1:]
-        x_path2 = self.path_2.avgpool(x_path2)
-        x_path2 = self.path_2.conv(x_path2)
+        x_path2 = self.path_2(x_relu)
         # final path
         x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
@ -253,49 +197,40 @@ class CellStem1(nn.Module):
|
|||||||
|
|
||||||
class FirstCell(nn.Module):
|
class FirstCell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||||
super(FirstCell, self).__init__()
|
super(FirstCell, self).__init__()
|
||||||
self.conv_1x1 = nn.Sequential()
|
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)
|
||||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
|
||||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
|
||||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.relu = nn.ReLU()
|
self.act = nn.ReLU()
|
||||||
self.path_1 = nn.Sequential()
|
self.path_1 = nn.Sequential()
|
||||||
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||||
self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
|
||||||
self.path_2 = nn.ModuleList()
|
|
||||||
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
self.path_2 = nn.Sequential()
|
||||||
|
self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
|
||||||
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||||
self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
|
||||||
|
|
||||||
self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
|
self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)
|
||||||
|
|
||||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
|
|
||||||
self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
|
|
||||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
|
|
||||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
|
|
||||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
|
|
||||||
def forward(self, x, x_prev):
|
def forward(self, x, x_prev):
|
||||||
x_relu = self.relu(x_prev)
|
x_relu = self.act(x_prev)
|
||||||
# path 1
|
|
||||||
x_path1 = self.path_1(x_relu)
|
x_path1 = self.path_1(x_relu)
|
||||||
# path 2
|
x_path2 = self.path_2(x_relu)
|
||||||
x_path2 = self.path_2.pad(x_relu)
|
|
||||||
x_path2 = x_path2[:, :, 1:, 1:]
|
|
||||||
x_path2 = self.path_2.avgpool(x_path2)
|
|
||||||
x_path2 = self.path_2.conv(x_path2)
|
|
||||||
# final path
|
|
||||||
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||||
|
|
||||||
x_right = self.conv_1x1(x)
|
x_right = self.conv_1x1(x)
|
||||||
|
|
||||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||||
@ -322,30 +257,23 @@ class FirstCell(nn.Module):
|
|||||||
|
|
||||||
class NormalCell(nn.Module):
|
class NormalCell(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||||
super(NormalCell, self).__init__()
|
super(NormalCell, self).__init__()
|
||||||
self.conv_prev_1x1 = nn.Sequential()
|
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
|
||||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.conv_1x1 = nn.Sequential()
|
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
|
||||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
|
||||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
|
||||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
|
||||||
self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
|
||||||
|
|
||||||
self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
|
self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
|
||||||
|
|
||||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
|
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
|
|
||||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
|
||||||
|
|
||||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x, x_prev):
|
def forward(self, x, x_prev):
|
||||||
x_left = self.conv_prev_1x1(x_prev)
|
x_left = self.conv_prev_1x1(x_prev)
|
||||||
@ -375,31 +303,24 @@ class NormalCell(nn.Module):
|
|||||||
|
|
||||||
class ReductionCell0(nn.Module):
|
class ReductionCell0(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||||
super(ReductionCell0, self).__init__()
|
super(ReductionCell0, self).__init__()
|
||||||
self.conv_prev_1x1 = nn.Sequential()
|
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
|
||||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.conv_1x1 = nn.Sequential()
|
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
|
||||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||||
self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||||
|
|
||||||
self.comb_iter_1_left = MaxPoolPad()
|
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||||
|
|
||||||
self.comb_iter_2_left = AvgPoolPad()
|
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
|
||||||
|
|
||||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
|
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||||
self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
|
||||||
self.comb_iter_4_right = MaxPoolPad()
|
|
||||||
|
|
||||||
def forward(self, x, x_prev):
|
def forward(self, x, x_prev):
|
||||||
x_left = self.conv_prev_1x1(x_prev)
|
x_left = self.conv_prev_1x1(x_prev)
|
||||||
@ -430,31 +351,24 @@ class ReductionCell0(nn.Module):
|
|||||||
|
|
||||||
class ReductionCell1(nn.Module):
|
class ReductionCell1(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
|
||||||
super(ReductionCell1, self).__init__()
|
super(ReductionCell1, self).__init__()
|
||||||
self.conv_prev_1x1 = nn.Sequential()
|
self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
|
||||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
|
||||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.conv_1x1 = nn.Sequential()
|
self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
|
||||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
|
||||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
|
||||||
|
|
||||||
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
|
||||||
|
|
||||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
|
||||||
self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
|
||||||
|
|
||||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
|
||||||
|
self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
|
||||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
|
||||||
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
|
||||||
|
|
||||||
def forward(self, x, x_prev):
|
def forward(self, x, x_prev):
|
||||||
x_left = self.conv_prev_1x1(x_prev)
|
x_left = self.conv_prev_1x1(x_prev)
|
||||||
@ -487,7 +401,7 @@ class NASNetALarge(nn.Module):
|
|||||||
"""NASNetALarge (6 @ 4032) """
|
"""NASNetALarge (6 @ 4032) """
|
||||||
|
|
||||||
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
|
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, num_features=4032, channel_multiplier=2,
|
||||||
drop_rate=0., global_pool='avg'):
|
drop_rate=0., global_pool='avg', pad_type='same'):
|
||||||
super(NASNetALarge, self).__init__()
|
super(NASNetALarge, self).__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.stem_size = stem_size
|
self.stem_size = stem_size
|
||||||
@ -498,60 +412,79 @@ class NASNetALarge(nn.Module):
|
|||||||
channels = self.num_features // 24
|
channels = self.num_features // 24
|
||||||
# 24 is default value for the architecture
|
# 24 is default value for the architecture
|
||||||
|
|
||||||
self.conv0 = nn.Sequential()
|
self.conv0 = ConvBnAct(
|
||||||
self.conv0.add_module('conv', nn.Conv2d(
|
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
|
||||||
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, bias=False))
|
norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
|
||||||
self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_size, eps=0.001, momentum=0.1, affine=True))
|
|
||||||
|
|
||||||
self.cell_stem_0 = CellStem0(self.stem_size, num_channels=channels // (channel_multiplier ** 2))
|
self.cell_stem_0 = CellStem0(
|
||||||
self.cell_stem_1 = CellStem1(self.stem_size, num_channels=channels // channel_multiplier)
|
self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
|
||||||
|
self.cell_stem_1 = CellStem1(
|
||||||
|
self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)
|
||||||
|
|
||||||
self.cell_0 = FirstCell(in_channels_left=channels, out_channels_left=channels // 2,
|
self.cell_0 = FirstCell(
|
||||||
in_channels_right=2 * channels, out_channels_right=channels)
|
in_chs_left=channels, out_chs_left=channels // 2,
|
||||||
self.cell_1 = NormalCell(in_channels_left=2 * channels, out_channels_left=channels,
|
in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
in_channels_right=6 * channels, out_channels_right=channels)
|
self.cell_1 = NormalCell(
|
||||||
self.cell_2 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
in_chs_left=2 * channels, out_chs_left=channels,
|
||||||
in_channels_right=6 * channels, out_channels_right=channels)
|
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
self.cell_3 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
self.cell_2 = NormalCell(
|
||||||
in_channels_right=6 * channels, out_channels_right=channels)
|
in_chs_left=6 * channels, out_chs_left=channels,
|
||||||
self.cell_4 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
in_channels_right=6 * channels, out_channels_right=channels)
|
self.cell_3 = NormalCell(
|
||||||
self.cell_5 = NormalCell(in_channels_left=6 * channels, out_channels_left=channels,
|
in_chs_left=6 * channels, out_chs_left=channels,
|
||||||
in_channels_right=6 * channels, out_channels_right=channels)
|
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
|
self.cell_4 = NormalCell(
|
||||||
|
in_chs_left=6 * channels, out_chs_left=channels,
|
||||||
|
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
|
self.cell_5 = NormalCell(
|
||||||
|
in_chs_left=6 * channels, out_chs_left=channels,
|
||||||
|
in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
|
||||||
|
|
||||||
self.reduction_cell_0 = ReductionCell0(in_channels_left=6 * channels, out_channels_left=2 * channels,
|
self.reduction_cell_0 = ReductionCell0(
|
||||||
in_channels_right=6 * channels, out_channels_right=2 * channels)
|
in_chs_left=6 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_6 = FirstCell(
|
||||||
|
in_chs_left=6 * channels, out_chs_left=channels,
|
||||||
|
in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_7 = NormalCell(
|
||||||
|
in_chs_left=8 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_8 = NormalCell(
|
||||||
|
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_9 = NormalCell(
|
||||||
|
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_10 = NormalCell(
|
||||||
|
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
self.cell_11 = NormalCell(
|
||||||
|
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||||
|
in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
|
||||||
|
|
||||||
self.cell_6 = FirstCell(in_channels_left=6 * channels, out_channels_left=channels,
|
self.reduction_cell_1 = ReductionCell1(
|
||||||
in_channels_right=8 * channels, out_channels_right=2 * channels)
|
in_chs_left=12 * channels, out_chs_left=4 * channels,
|
||||||
self.cell_7 = NormalCell(in_channels_left=8 * channels, out_channels_left=2 * channels,
|
in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
self.cell_12 = FirstCell(
|
||||||
self.cell_8 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
in_chs_left=12 * channels, out_chs_left=2 * channels,
|
||||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
self.cell_9 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
self.cell_13 = NormalCell(
|
||||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
in_chs_left=16 * channels, out_chs_left=4 * channels,
|
||||||
self.cell_10 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
self.cell_14 = NormalCell(
|
||||||
self.cell_11 = NormalCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||||
in_channels_right=12 * channels, out_channels_right=2 * channels)
|
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
|
self.cell_15 = NormalCell(
|
||||||
|
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||||
|
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
|
self.cell_16 = NormalCell(
|
||||||
|
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||||
|
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
|
self.cell_17 = NormalCell(
|
||||||
|
in_chs_left=24 * channels, out_chs_left=4 * channels,
|
||||||
|
in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
|
||||||
|
|
||||||
self.reduction_cell_1 = ReductionCell1(in_channels_left=12 * channels, out_channels_left=4 * channels,
|
self.act = nn.ReLU(inplace=True)
|
||||||
in_channels_right=12 * channels, out_channels_right=4 * channels)
|
|
||||||
|
|
||||||
self.cell_12 = FirstCell(in_channels_left=12 * channels, out_channels_left=2 * channels,
|
|
||||||
in_channels_right=16 * channels, out_channels_right=4 * channels)
|
|
||||||
self.cell_13 = NormalCell(in_channels_left=16 * channels, out_channels_left=4 * channels,
|
|
||||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
||||||
self.cell_14 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
||||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
||||||
self.cell_15 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
||||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
||||||
self.cell_16 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
||||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
||||||
self.cell_17 = NormalCell(in_channels_left=24 * channels, out_channels_left=4 * channels,
|
|
||||||
in_channels_right=24 * channels, out_channels_right=4 * channels)
|
|
||||||
|
|
||||||
self.relu = nn.ReLU()
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||||
|
|
||||||
@@ -569,8 +502,11 @@ class NASNetALarge(nn.Module):
     def forward_features(self, x):
         x_conv0 = self.conv0(x)
+        #0

         x_stem_0 = self.cell_stem_0(x_conv0)
         x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
+        #1

         x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
         x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
@@ -578,25 +514,27 @@
         x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
         x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
         x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
+        #2

         x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)

         x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
         x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
         x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
         x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
         x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
         x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
+        #3

         x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)

         x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
         x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
         x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
         x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
         x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
         x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
-        x = self.relu(x_cell_17)
+        x = self.act(x_cell_17)
+        #4

         return x

     def forward(self, x):
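The #0-#4 markers note where the network's spatial resolution drops, i.e. the candidate feature taps this feature-extraction work is aiming at. For ordinary use the staged tensors are not returned; the model is either run end to end or forward_features() is called for the final pre-pooling map. A minimal sketch, with the same create_model usage assumed as elsewhere in timm:

    import torch
    import timm

    model = timm.create_model('nasnetalarge', pretrained=False)
    model.eval()
    with torch.no_grad():
        feats = model.forward_features(torch.randn(1, 3, 331, 331))  # default input size is 331x331
    print(feats.shape)  # approximately (1, 4032, 11, 11) before global pooling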
@@ -14,7 +14,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d
+from .layers import SelectAdaptivePool2d, ConvBnAct, create_conv2d, create_pool2d
 from .registry import register_model

 __all__ = ['PNASNet5Large']
@ -35,34 +35,15 @@ default_cfgs = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MaxPool(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False):
|
|
||||||
super(MaxPool, self).__init__()
|
|
||||||
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
|
|
||||||
self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.zero_pad is not None:
|
|
||||||
x = self.zero_pad(x)
|
|
||||||
x = self.pool(x)
|
|
||||||
x = x[:, :, 1:, 1:]
|
|
||||||
else:
|
|
||||||
x = self.pool(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SeparableConv2d(nn.Module):
|
class SeparableConv2d(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride,
|
def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
|
||||||
dw_padding):
|
|
||||||
super(SeparableConv2d, self).__init__()
|
super(SeparableConv2d, self).__init__()
|
||||||
self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels,
|
self.depthwise_conv2d = create_conv2d(
|
||||||
kernel_size=dw_kernel_size,
|
in_channels, in_channels, kernel_size=kernel_size,
|
||||||
stride=dw_stride, padding=dw_padding,
|
stride=stride, padding=padding, groups=in_channels)
|
||||||
groups=in_channels, bias=False)
|
self.pointwise_conv2d = create_conv2d(
|
||||||
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels,
|
in_channels, out_channels, kernel_size=1, padding=padding)
|
||||||
kernel_size=1, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.depthwise_conv2d(x)
|
x = self.depthwise_conv2d(x)
|
||||||
@ -72,50 +53,39 @@ class SeparableConv2d(nn.Module):
|
|||||||
|
|
||||||
class BranchSeparables(nn.Module):
|
class BranchSeparables(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
|
||||||
stem_cell=False, zero_pad=False):
|
|
||||||
super(BranchSeparables, self).__init__()
|
super(BranchSeparables, self).__init__()
|
||||||
padding = kernel_size // 2
|
|
||||||
middle_channels = out_channels if stem_cell else in_channels
|
middle_channels = out_channels if stem_cell else in_channels
|
||||||
self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
|
self.act_1 = nn.ReLU()
|
||||||
self.relu_1 = nn.ReLU()
|
self.separable_1 = SeparableConv2d(
|
||||||
self.separable_1 = SeparableConv2d(in_channels, middle_channels,
|
in_channels, middle_channels, kernel_size, stride=stride, padding=padding)
|
||||||
kernel_size, dw_stride=stride,
|
|
||||||
dw_padding=padding)
|
|
||||||
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
|
self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
|
||||||
self.relu_2 = nn.ReLU()
|
self.act_2 = nn.ReLU()
|
||||||
self.separable_2 = SeparableConv2d(middle_channels, out_channels,
|
self.separable_2 = SeparableConv2d(
|
||||||
kernel_size, dw_stride=1,
|
middle_channels, out_channels, kernel_size, stride=1, padding=padding)
|
||||||
dw_padding=padding)
|
|
||||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
|
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.relu_1(x)
|
x = self.act_1(x)
|
||||||
if self.zero_pad is not None:
|
x = self.separable_1(x)
|
||||||
x = self.zero_pad(x)
|
|
||||||
x = self.separable_1(x)
|
|
||||||
x = x[:, :, 1:, 1:].contiguous()
|
|
||||||
else:
|
|
||||||
x = self.separable_1(x)
|
|
||||||
x = self.bn_sep_1(x)
|
x = self.bn_sep_1(x)
|
||||||
x = self.relu_2(x)
|
x = self.act_2(x)
|
||||||
x = self.separable_2(x)
|
x = self.separable_2(x)
|
||||||
x = self.bn_sep_2(x)
|
x = self.bn_sep_2(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ReluConvBn(nn.Module):
|
class ActConvBn(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
|
||||||
super(ReluConvBn, self).__init__()
|
super(ActConvBn, self).__init__()
|
||||||
self.relu = nn.ReLU()
|
self.act = nn.ReLU()
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels,
|
self.conv = create_conv2d(
|
||||||
kernel_size=kernel_size, stride=stride,
|
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
bias=False)
|
|
||||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.relu(x)
|
x = self.act(x)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.bn(x)
|
x = self.bn(x)
|
||||||
return x
|
return x
|
||||||
@@ -123,32 +93,24 @@ class ReluConvBn(nn.Module):

 class FactorizedReduction(nn.Module):

-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels, out_channels, padding=''):
         super(FactorizedReduction, self).__init__()
-        self.relu = nn.ReLU()
+        self.act = nn.ReLU()
         self.path_1 = nn.Sequential(OrderedDict([
             ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
-            ('conv', nn.Conv2d(in_channels, out_channels // 2,
-                               kernel_size=1, bias=False)),
+            ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
         ]))
         self.path_2 = nn.Sequential(OrderedDict([
-            ('pad', nn.ZeroPad2d((0, 1, 0, 1))),
+            ('pad', nn.ZeroPad2d((-1, 1, -1, 1))),  # shift
             ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
-            ('conv', nn.Conv2d(in_channels, out_channels // 2,
-                               kernel_size=1, bias=False)),
+            ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
         ]))
         self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)

     def forward(self, x):
-        x = self.relu(x)
+        x = self.act(x)

         x_path1 = self.path_1(x)
-        x_path2 = self.path_2.pad(x)
-        x_path2 = x_path2[:, :, 1:, 1:]
-        x_path2 = self.path_2.avgpool(x_path2)
-        x_path2 = self.path_2.conv(x_path2)
+        x_path2 = self.path_2(x)

         out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
         return out
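FactorizedReduction halves the resolution with two parallel stride-2 paths whose sampling grids are offset by one pixel, then concatenates them; the new ZeroPad2d((-1, 1, -1, 1)) crops one row/column at the top-left and zero-pads one at the bottom-right, so the second path's AvgPool2d(1, stride=2) lands on the in-between pixels the first path skips. A tiny self-contained check of that shift (toy tensor, not the module above):

    import torch
    import torch.nn as nn

    x = torch.arange(16.).reshape(1, 1, 4, 4)
    pool = nn.AvgPool2d(1, stride=2, count_include_pad=False)
    shift = nn.ZeroPad2d((-1, 1, -1, 1))  # negative padding crops, positive pads

    print(pool(x).squeeze())         # tensor([[ 0.,  2.], [ 8., 10.]]) - even-index pixels
    print(pool(shift(x)).squeeze())  # tensor([[ 5.,  7.], [13., 15.]]) - odd-index pixels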
@ -179,49 +141,41 @@ class CellBase(nn.Module):
|
|||||||
x_comb_iter_4_right = x_right
|
x_comb_iter_4_right = x_right
|
||||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||||
|
|
||||||
x_out = torch.cat(
|
x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||||
[x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
|
||||||
return x_out
|
return x_out
|
||||||
|
|
||||||
|
|
||||||
class CellStem0(CellBase):
|
class CellStem0(CellBase):
|
||||||
|
|
||||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right,
|
def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding=''):
|
||||||
out_channels_right):
|
|
||||||
super(CellStem0, self).__init__()
|
super(CellStem0, self).__init__()
|
||||||
self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
|
self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)
|
||||||
kernel_size=1)
|
|
||||||
self.comb_iter_0_left = BranchSeparables(in_channels_left,
|
self.comb_iter_0_left = BranchSeparables(
|
||||||
out_channels_left,
|
in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=padding)
|
||||||
kernel_size=5, stride=2,
|
|
||||||
stem_cell=True)
|
|
||||||
self.comb_iter_0_right = nn.Sequential(OrderedDict([
|
self.comb_iter_0_right = nn.Sequential(OrderedDict([
|
||||||
('max_pool', MaxPool(3, stride=2)),
|
('max_pool', create_pool2d('max', 3, stride=2, padding=padding)),
|
||||||
('conv', nn.Conv2d(in_channels_left, out_channels_left,
|
('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=padding)),
|
||||||
kernel_size=1, bias=False)),
|
('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
|
||||||
('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)),
|
|
||||||
]))
|
]))
|
||||||
self.comb_iter_1_left = BranchSeparables(out_channels_right,
|
|
||||||
out_channels_right,
|
self.comb_iter_1_left = BranchSeparables(
|
||||||
kernel_size=7, stride=2)
|
out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=padding)
|
||||||
self.comb_iter_1_right = MaxPool(3, stride=2)
|
self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=padding)
|
||||||
self.comb_iter_2_left = BranchSeparables(out_channels_right,
|
|
||||||
out_channels_right,
|
self.comb_iter_2_left = BranchSeparables(
|
||||||
kernel_size=5, stride=2)
|
out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=padding)
|
||||||
self.comb_iter_2_right = BranchSeparables(out_channels_right,
|
self.comb_iter_2_right = BranchSeparables(
|
||||||
out_channels_right,
|
out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=padding)
|
||||||
kernel_size=3, stride=2)
|
|
||||||
self.comb_iter_3_left = BranchSeparables(out_channels_right,
|
self.comb_iter_3_left = BranchSeparables(
|
||||||
out_channels_right,
|
out_chs_right, out_chs_right, kernel_size=3, padding=padding)
|
||||||
kernel_size=3)
|
self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=padding)
|
||||||
        self.comb_iter_4_left = BranchSeparables(
            in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=padding)
        self.comb_iter_4_right = ActConvBn(
            out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=padding)

    def forward(self, x_left):
        x_right = self.conv_1x1(x_left)

@@ -231,9 +185,8 @@ class CellStem0(CellBase):

class Cell(CellBase):

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, padding='',
                 is_reduction=False, match_prev_layer_dims=False):
        super(Cell, self).__init__()

        # If `is_reduction` is set to `True` stride 2 is used for
@@ -244,45 +197,34 @@ class Cell(CellBase):
        # If `match_prev_layer_dimensions` is set to `True`
        # `FactorizedReduction` is used to reduce the spatial size
        # of the left input of a cell approximately by a factor of 2.
        self.match_prev_layer_dimensions = match_prev_layer_dims
        if match_prev_layer_dims:
            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=padding)
        else:
            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=padding)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=padding)

        self.comb_iter_0_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=padding)
        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=padding)

        self.comb_iter_1_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=padding)
        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=padding)

        self.comb_iter_2_left = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=padding)
        self.comb_iter_2_right = BranchSeparables(
            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=padding)

        self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=padding)

        self.comb_iter_4_left = BranchSeparables(
            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=padding)
        if is_reduction:
            self.comb_iter_4_right = ActConvBn(
                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=padding)
        else:
            self.comb_iter_4_right = None

@@ -294,52 +236,53 @@ class Cell(CellBase):


class PNASNet5Large(nn.Module):
    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg', padding=''):
        super(PNASNet5Large, self).__init__()
        self.num_classes = num_classes
        self.num_features = 4320
        self.drop_rate = drop_rate

        self.conv_0 = ConvBnAct(
            in_chans, 96, kernel_size=3, stride=2, padding=0,
            norm_kwargs=dict(eps=0.001, momentum=0.1), act_layer=None)
        self.cell_stem_0 = CellStem0(
            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, padding=padding)
        self.cell_stem_1 = Cell(
            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, padding=padding,
            match_prev_layer_dims=True, is_reduction=True)
        self.cell_0 = Cell(
            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, padding=padding,
            match_prev_layer_dims=True)
        self.cell_1 = Cell(
            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
        self.cell_2 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)
        self.cell_3 = Cell(
            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, padding=padding)

        self.cell_4 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, padding=padding,
            is_reduction=True)
        self.cell_5 = Cell(
            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding,
            match_prev_layer_dims=True)
        self.cell_6 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)
        self.cell_7 = Cell(
            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, padding=padding)

        self.cell_8 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, padding=padding,
            is_reduction=True)
        self.cell_9 = Cell(
            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding,
            match_prev_layer_dims=True)
        self.cell_10 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)
        self.cell_11 = Cell(
            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, padding=padding)

        self.relu = nn.ReLU()
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

@@ -391,7 +334,7 @@ def pnasnet5large(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    <https://arxiv.org/abs/1712.00559>`_ paper.
    """
    default_cfg = default_cfgs['pnasnet5large']
    model = PNASNet5Large(num_classes=num_classes, in_chans=in_chans, padding='same', **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
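The PNASNet changes above are argument plumbing only (in_channels_*/out_channels_* become in_chs_*/out_chs_*, zero_pad is replaced by a padding string, and pnasnet5large() now passes padding='same'), so the public entrypoint is called the same way as before. A minimal smoke-test sketch, assuming this file remains importable as timm.models.pnasnet and keeps the default 331x331 input size:

# Hypothetical usage, not part of the diff: exercises the refactored pnasnet5large().
import torch
from timm.models.pnasnet import pnasnet5large

model = pnasnet5large(pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 331, 331))
print(logits.shape)  # expected: torch.Size([1, 1000])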
timm/models/res2net.py

@@ -10,7 +10,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .registry import register_model
from .resnet import _create_resnet_with_cfg

__all__ = []

@@ -132,113 +132,83 @@ class Bottle2neck(nn.Module):
        return out


def _create_res2net(variant, pretrained=False, **kwargs):
    default_cfg = default_cfgs[variant]
    return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)


@register_model
def res2net50_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
    return _create_res2net('res2net50_26w_4s', pretrained, **model_args)

@register_model
def res2net101_26w_4s(pretrained=False, **kwargs):
    """Constructs a Res2Net-101 26w4s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
    return _create_res2net('res2net101_26w_4s', pretrained, **model_args)

@register_model
def res2net50_26w_6s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w6s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
    return _create_res2net('res2net50_26w_6s', pretrained, **model_args)

@register_model
def res2net50_26w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 26w8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
    return _create_res2net('res2net50_26w_8s', pretrained, **model_args)

@register_model
def res2net50_48w_2s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 48w2s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
    return _create_res2net('res2net50_48w_2s', pretrained, **model_args)

@register_model
def res2net50_14w_8s(pretrained=False, **kwargs):
    """Constructs a Res2Net-50 14w8s model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
    return _create_res2net('res2net50_14w_8s', pretrained, **model_args)

@register_model
def res2next50(pretrained=False, **kwargs):
    """Construct Res2NeXt-50 4s
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
    return _create_res2net('res2next50', pretrained, **model_args)
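In the refactored res2net.py above, each variant name encodes the Bottle2neck hyper-parameters, e.g. '26w_4s' means base_width=26 and block_args=dict(scale=4); the construction boilerplate now lives in _create_res2net(). A hedged sketch of creating one of these through the timm registry (assuming timm.create_model dispatches to the @register_model entrypoints as usual):

# Sketch only: relies on the registered res2net50_26w_4s() entrypoint above.
import torch
import timm

model = timm.create_model('res2net50_26w_4s', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])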
timm/models/resnest.py

@@ -6,18 +6,14 @@ Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang198
Modified for torchscript compat, and consistency with timm by Ross Wightman
"""
import torch
from torch import nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropBlock2d
from .layers.split_attn import SplitAttnConv2d
from .registry import register_model
from .resnet import _create_resnet_with_cfg


def _cfg(url='', **kwargs):
@@ -143,125 +139,98 @@ class ResNestBottleneck(nn.Module):
        return out


def _create_resnest(variant, pretrained=False, **kwargs):
    default_cfg = default_cfgs[variant]
    return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)


@register_model
def resnest14d(pretrained=False, **kwargs):
    """ ResNeSt-14d model. Weights ported from GluonCV.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[1, 1, 1, 1],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)

@register_model
def resnest26d(pretrained=False, **kwargs):
    """ ResNeSt-26d model. Weights ported from GluonCV.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[2, 2, 2, 2],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)

@register_model
def resnest50d(pretrained=False, **kwargs):
    """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)

@register_model
def resnest101e(pretrained=False, **kwargs):
    """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 23, 3],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)

@register_model
def resnest200e(pretrained=False, **kwargs):
    """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 24, 36, 3],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)

@register_model
def resnest269e(pretrained=False, **kwargs):
    """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
    Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 30, 48, 8],
        stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
    return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)

@register_model
def resnest50d_4s2x40d(pretrained=False, **kwargs):
    """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
        block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
    return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)

@register_model
def resnest50d_1s4x24d(pretrained=False, **kwargs):
    """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
    """
    model_kwargs = dict(
        block=ResNestBottleneck, layers=[3, 4, 6, 3],
        stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
        block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
    return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
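The ResNeSt variant names above decode the same way: 'resnest50d_4s2x40d' means radix=4 split attention, cardinality=2, base_width=40, while the plain 'd'/'e' models differ only in stem_width (32 vs 64). A sketch of what resnest50d() effectively does via the private factory, using only names defined in this file (illustrative, not part of the commit):

# Equivalent to resnest50d(pretrained=False) with default kwargs.
model = _create_resnest(
    'resnest50d', pretrained=False,
    block=ResNestBottleneck, layers=[3, 4, 6, 3],
    stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
    block_args=dict(radix=2, avd=True, avd_first=False))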
timm/models/resnet.py

@@ -6,11 +6,14 @@ additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
"""
import math
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .features import FeatureNet
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn, BlurPool2d
from .registry import register_model

@@ -390,6 +393,7 @@ class ResNet(nn.Module):
        self.base_width = base_width
        self.drop_rate = drop_rate
        self.expansion = block.expansion
        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
        super(ResNet, self).__init__()

        # Stem
@@ -420,9 +424,6 @@ class ResNet(nn.Module):
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
        if output_stride == 16:
            strides[3] = 1
@@ -432,14 +433,23 @@ class ResNet(nn.Module):
            dilations[2:4] = [2, 4]
        else:
            assert output_stride == 32
        dp = DropPath(drop_path_rate) if drop_path_rate else None
        db = [
            None, None,
            DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
            DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
        layer_args = list(zip(channels, layers, strides, dilations))
        layer_kwargs = dict(
            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
        current_stride = 4
        for i in range(4):
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, self._make_layer(
                block, *layer_args[i], drop_block=db[i], **layer_kwargs))
            current_stride *= strides[i]
            self.feature_info.append(dict(
                num_chs=self.inplanes, reduction=current_stride, module=layer_name))

        # Head (Pooling and Classifier)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
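The stage loop above appends one feature_info entry per layer as it is built. For a stock ResNet-50 (Bottleneck expansion 4, output_stride 32) the collected metadata works out roughly as follows; the values are derived from the channel/stride schedule above, not printed by the code:

# Illustrative only: contents of self.feature_info after __init__ for resnet50.
feature_info = [
    dict(num_chs=64, reduction=2, module='act1'),      # stem
    dict(num_chs=256, reduction=4, module='layer1'),
    dict(num_chs=512, reduction=8, module='layer2'),
    dict(num_chs=1024, reduction=16, module='layer3'),
    dict(num_chs=2048, reduction=32, module='layer4'),
]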
@@ -509,245 +519,185 @@ class ResNet(nn.Module):
        return x


def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
    assert isinstance(default_cfg, dict)
    load_strict, features = True, False
    out_indices = None
    if kwargs.pop('features_only', False):
        load_strict, features = False, True
        kwargs.pop('num_classes', 0)
        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
    model = ResNet(**kwargs)
    model.default_cfg = copy.deepcopy(default_cfg)
    if kwargs.pop('pruned', False):
        model = adapt_model_from_file(model, variant)
    if pretrained:
        load_pretrained(
            model,
            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
    if features:
        model = FeatureNet(model, out_indices=out_indices)
    return model


def _create_resnet(variant, pretrained=False, **kwargs):
    default_cfg = default_cfgs[variant]
    return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)

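When features_only is requested, _create_resnet_with_cfg() above drops the classifier-specific kwargs, loads weights non-strictly, and wraps the model in FeatureNet, with out_indices selecting which feature_info stages are returned. A hedged sketch of the intended user-facing call (the feature-extraction API is still work in progress in this commit, so treat this as an assumption about where it is heading):

# Sketch: assumes timm.create_model forwards features_only/out_indices into the
# factory above, and that FeatureNet returns one feature map per selected stage.
import torch
import timm

fnet = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3, 4))
features = fnet(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in features])
# roughly: [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]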
@register_model
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('resnet18', pretrained, **model_args)

@register_model
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet34', pretrained, **model_args)

@register_model
def resnet26(pretrained=False, **kwargs):
    """Constructs a ResNet-26 model.
    """
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('resnet26', pretrained, **model_args)

@register_model
def resnet26d(pretrained=False, **kwargs):
    """Constructs a ResNet-26 v1d model.
    This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
    """
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnet26d', pretrained, **model_args)

@register_model
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('resnet50', pretrained, **model_args)

@register_model
def resnet50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnet50d', pretrained, **model_args)

@register_model
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
    return _create_resnet('resnet101', pretrained, **model_args)

@register_model
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
    return _create_resnet('resnet152', pretrained, **model_args)

@register_model
def tv_resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model with original Torchvision weights.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('tv_resnet34', pretrained, **model_args)

@register_model
def tv_resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model with original Torchvision weights.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('tv_resnet50', pretrained, **model_args)

@register_model
def wide_resnet50_2(pretrained=False, **kwargs):
    """Constructs a Wide ResNet-50-2 model.
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
    return _create_resnet('wide_resnet50_2', pretrained, **model_args)

@register_model
def wide_resnet101_2(pretrained=False, **kwargs):
    """Constructs a Wide ResNet-101-2 model.
    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
    return _create_resnet('wide_resnet101_2', pretrained, **model_args)

@register_model
def resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50-32x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('resnext50_32x4d', pretrained, **model_args)

@register_model
def resnext50d_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
    return _create_resnet('resnext50d_32x4d', pretrained, **model_args)

@register_model
def resnext101_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('resnext101_32x4d', pretrained, **model_args)

@register_model
def resnext101_32x8d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x8d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('resnext101_32x8d', pretrained, **model_args)

@register_model
def resnext101_64x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt101-64x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
    return _create_resnet('resnext101_64x4d', pretrained, **model_args)

@register_model
def tv_resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)

@register_model
@@ -757,11 +707,8 @@ def ig_resnext101_32x8d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)

@register_model
@@ -771,11 +718,8 @@ def ig_resnext101_32x16d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)

@register_model
@@ -785,11 +729,8 @@ def ig_resnext101_32x32d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
    return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)

@register_model
@@ -799,11 +740,8 @@ def ig_resnext101_32x48d(pretrained=True, **kwargs):
    `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
    Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
    return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)

@register_model
@@ -812,11 +750,8 @@ def ssl_resnet18(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
    return _create_resnet('ssl_resnet18', pretrained, **model_args)

@register_model
@@ -825,11 +760,8 @@ def ssl_resnet50(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
    return _create_resnet('ssl_resnet50', pretrained, **model_args)

@register_model
@@ -838,11 +770,8 @@ def ssl_resnext50_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)

@register_model
@@ -851,11 +780,8 @@ def ssl_resnext101_32x4d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
    return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)

@register_model
@@ -864,11 +790,8 @@ def ssl_resnext101_32x8d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
    return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)

@register_model
@@ -877,11 +800,8 @@ def ssl_resnext101_32x16d(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
    return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)

@register_model
@@ -891,11 +811,8 @@ def swsl_resnet18(pretrained=True, **kwargs):
    `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
    Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
|
return _create_resnet('swsl_resnet18', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -905,11 +822,8 @@ def swsl_resnet50(pretrained=True, **kwargs):
|
|||||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||||
"""
|
"""
|
||||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
|
||||||
model.default_cfg = default_cfgs['swsl_resnet50']
|
return _create_resnet('swsl_resnet50', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -919,11 +833,8 @@ def swsl_resnext50_32x4d(pretrained=True, **kwargs):
|
|||||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||||
"""
|
"""
|
||||||
model = ResNet(Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
|
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
|
||||||
model.default_cfg = default_cfgs['swsl_resnext50_32x4d']
|
return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -933,11 +844,8 @@ def swsl_resnext101_32x4d(pretrained=True, **kwargs):
|
|||||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||||
"""
|
"""
|
||||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
|
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
|
||||||
model.default_cfg = default_cfgs['swsl_resnext101_32x4d']
|
return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -947,11 +855,8 @@ def swsl_resnext101_32x8d(pretrained=True, **kwargs):
|
|||||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||||
"""
|
"""
|
||||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
|
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
|
||||||
model.default_cfg = default_cfgs['swsl_resnext101_32x8d']
|
return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -961,61 +866,44 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
|
|||||||
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
|
||||||
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
|
||||||
"""
|
"""
|
||||||
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
|
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
|
||||||
model.default_cfg = default_cfgs['swsl_resnext101_32x16d']
|
return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def seresnext26d_32x4d(pretrained=False, **kwargs):
|
||||||
"""Constructs a SE-ResNeXt-26-D model.
|
"""Constructs a SE-ResNeXt-26-D model.
|
||||||
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
|
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
|
||||||
combination of deep stem and avg_pool in downsample.
|
combination of deep stem and avg_pool in downsample.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['seresnext26d_32x4d']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
||||||
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True,
|
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
|
return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def seresnext26t_32x4d(pretrained=False, **kwargs):
|
||||||
"""Constructs a SE-ResNet-26-T model.
|
"""Constructs a SE-ResNet-26-T model.
|
||||||
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
|
||||||
in the deep stem.
|
in the deep stem.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['seresnext26t_32x4d']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
||||||
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
|
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
|
||||||
stem_width=32, stem_type='deep_tiered', avg_down=True,
|
return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def seresnext26tn_32x4d(pretrained=False, **kwargs):
|
||||||
"""Constructs a SE-ResNeXt-26-TN model.
|
"""Constructs a SE-ResNeXt-26-TN model.
|
||||||
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
|
||||||
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['seresnext26tn_32x4d']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
||||||
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
|
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
|
||||||
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
|
return _create_resnet('seresnext26tn_32x4d', pretrained, **model_args)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
@ -1025,145 +913,91 @@ def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwarg
|
|||||||
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
|
||||||
this model replaces SE module with the ECA module
|
this model replaces SE module with the ECA module
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnext26tn_32x4d']
|
model_args = dict(
|
||||||
block_args = dict(attn_layer='eca')
|
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
|
||||||
model = ResNet(
|
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
|
return _create_resnet('ecaresnext26tn_32x4d', pretrained, **model_args)
|
||||||
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
|
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet18(pretrained=False, **kwargs):
|
||||||
""" Constructs an ECA-ResNet-18 model.
|
""" Constructs an ECA-ResNet-18 model.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnet18']
|
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
block_args = dict(attn_layer='eca')
|
return _create_resnet('ecaresnet18', pretrained, **model_args)
|
||||||
model = ResNet(
|
|
||||||
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet50(pretrained=False, **kwargs):
|
||||||
"""Constructs an ECA-ResNet-50 model.
|
"""Constructs an ECA-ResNet-50 model.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnet50']
|
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
block_args = dict(attn_layer='eca')
|
return _create_resnet('ecaresnet50', pretrained, **model_args)
|
||||||
model = ResNet(
|
|
||||||
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet50d(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-50-D model with eca.
|
"""Constructs a ResNet-50-D model with eca.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnet50d']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
|
return _create_resnet('ecaresnet50d', pretrained, **model_args)
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet50d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet50d_pruned(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-50-D model pruned with eca.
|
"""Constructs a ResNet-50-D model pruned with eca.
|
||||||
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
|
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
|
||||||
"""
|
"""
|
||||||
variant = 'ecaresnet50d_pruned'
|
model_args = dict(
|
||||||
default_cfg = default_cfgs[variant]
|
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
model = ResNet(
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
|
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
model = adapt_model_from_file(model, variant)
|
|
||||||
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnetlight(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnetlight(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-50-D light model with eca.
|
"""Constructs a ResNet-50-D light model with eca.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnetlight']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
|
||||||
Bottleneck, [1, 1, 11, 3], stem_width=32, avg_down=True,
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
|
return _create_resnet('ecaresnetlight', pretrained, **model_args)
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet101d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet101d(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-101-D model with eca.
|
"""Constructs a ResNet-101-D model with eca.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['ecaresnet101d']
|
model_args = dict(
|
||||||
model = ResNet(
|
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
|
return _create_resnet('ecaresnet101d', pretrained, **model_args)
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnet101d_pruned(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def ecaresnet101d_pruned(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-101-D model pruned with eca.
|
"""Constructs a ResNet-101-D model pruned with eca.
|
||||||
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
|
The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
|
||||||
"""
|
"""
|
||||||
variant = 'ecaresnet101d_pruned'
|
model_args = dict(
|
||||||
default_cfg = default_cfgs[variant]
|
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
model = ResNet(
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
Bottleneck, [3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
|
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
|
||||||
num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='eca'), **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
model = adapt_model_from_file(model, variant)
|
|
||||||
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnetblur18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def resnetblur18(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-18 model with blur anti-aliasing
|
"""Constructs a ResNet-18 model with blur anti-aliasing
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['resnetblur18']
|
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
|
||||||
model = ResNet(
|
return _create_resnet('resnetblur18', pretrained, **model_args)
|
||||||
BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnetblur50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def resnetblur50(pretrained=False, **kwargs):
|
||||||
"""Constructs a ResNet-50 model with blur anti-aliasing
|
"""Constructs a ResNet-50 model with blur anti-aliasing
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['resnetblur50']
|
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
|
||||||
model = ResNet(
|
return _create_resnet('resnetblur50', pretrained, **model_args)
|
||||||
Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, aa_layer=BlurPool2d, **kwargs)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
return model
|
|
||||||
|
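For context, every ResNet-family entrypoint above now just assembles a `model_args` dict and delegates to the shared `_create_resnet` helper, so the registered names are still created the usual way through the model registry. A minimal usage sketch, assuming this branch is installed locally; the model name and `num_classes` value are examples only:

import torch
from timm.models import create_model

# 'ecaresnet50d' is one of the entrypoints refactored above; extra kwargs such as
# num_classes flow through model_args into the ResNet constructor.
model = create_model('ecaresnet50d', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 10])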
@@ -16,6 +16,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .features import FeatureNet
 from .helpers import load_pretrained
 from .layers import SelectAdaptivePool2d
 from .registry import register_model
@@ -100,7 +101,8 @@ class SelecSLSBlock(nn.Module):
         self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)

     def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
-        assert isinstance(x, list)
+        if not isinstance(x, list):
+            x = [x]
         assert len(x) in [1, 2]

         d1 = self.conv1(x[0])
@@ -163,7 +165,7 @@ class SelecSLS(nn.Module):

     def forward_features(self, x):
         x = self.stem(x)
-        x = self.features([x])
+        x = self.features(x)
         x = self.head(x[0])
         return x

@@ -178,6 +180,7 @@ class SelecSLS(nn.Module):

 def _create_model(variant, pretrained, model_kwargs):
     cfg = {}
+    feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
     if variant.startswith('selecsls42'):
         cfg['block'] = SelecSLSBlock
         # Define configuration of the network after the initial neck
@@ -190,7 +193,13 @@ def _create_model(variant, pretrained, model_kwargs):
             (288, 0, 304, 304, True, 2),
             (304, 304, 304, 480, False, 1),
         ]
+        feature_info.extend([
+            dict(num_chs=128, reduction=4, module='features.1'),
+            dict(num_chs=288, reduction=8, module='features.3'),
+            dict(num_chs=480, reduction=16, module='features.5'),
+        ])
         # Head can be replaced with alternative configurations depending on the problem
+        feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
         if variant == 'selecsls42b':
             cfg['head'] = [
                 (480, 960, 3, 2),
@@ -198,6 +207,7 @@ def _create_model(variant, pretrained, model_kwargs):
                 (1024, 1280, 3, 2),
                 (1280, 1024, 1, 1),
             ]
+            feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
             cfg['num_features'] = 1024
         else:
             cfg['head'] = [
@@ -206,7 +216,9 @@ def _create_model(variant, pretrained, model_kwargs):
                 (1024, 1024, 3, 2),
                 (1024, 1280, 1, 1),
             ]
+            feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
             cfg['num_features'] = 1280

     elif variant.startswith('selecsls60'):
         cfg['block'] = SelecSLSBlock
         # Define configuration of the network after the initial neck
@@ -222,7 +234,13 @@ def _create_model(variant, pretrained, model_kwargs):
             (288, 288, 288, 288, False, 1),
             (288, 288, 288, 416, False, 1),
         ]
+        feature_info.extend([
+            dict(num_chs=128, reduction=4, module='features.1'),
+            dict(num_chs=288, reduction=8, module='features.4'),
+            dict(num_chs=416, reduction=16, module='features.8'),
+        ])
         # Head can be replaced with alternative configurations depending on the problem
+        feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
         if variant == 'selecsls60b':
             cfg['head'] = [
                 (416, 756, 3, 2),
@@ -230,6 +248,7 @@ def _create_model(variant, pretrained, model_kwargs):
                 (1024, 1280, 3, 2),
                 (1280, 1024, 1, 1),
             ]
+            feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
             cfg['num_features'] = 1024
         else:
             cfg['head'] = [
@@ -238,7 +257,9 @@ def _create_model(variant, pretrained, model_kwargs):
                 (1024, 1024, 3, 2),
                 (1024, 1280, 1, 1),
             ]
+            feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
             cfg['num_features'] = 1280

     elif variant == 'selecsls84':
         cfg['block'] = SelecSLSBlock
         # Define configuration of the network after the initial neck
@@ -258,6 +279,11 @@ def _create_model(variant, pretrained, model_kwargs):
             (304, 304, 304, 304, False, 1),
             (304, 304, 304, 512, False, 1),
         ]
+        feature_info.extend([
+            dict(num_chs=144, reduction=4, module='features.1'),
+            dict(num_chs=304, reduction=8, module='features.6'),
+            dict(num_chs=512, reduction=16, module='features.12'),
+        ])
         # Head can be replaced with alternative configurations depending on the problem
         cfg['head'] = [
             (512, 960, 3, 2),
@@ -266,17 +292,35 @@ def _create_model(variant, pretrained, model_kwargs):
             (1024, 1280, 3, 1),
         ]
         cfg['num_features'] = 1280
+        feature_info.extend([
+            dict(num_chs=1024, reduction=32, module='head.1'),
+            dict(num_chs=1280, reduction=64, module='head.3')
+        ])
     else:
         raise ValueError('Invalid net configuration ' + variant + ' !!!')

+    load_strict = True
+    features = False
+    out_indices = None
+    if model_kwargs.pop('features_only', False):
+        load_strict = False
+        features = True
+        # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
+        out_indices = model_kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+        model_kwargs.pop('num_classes', 0)
+
     model = SelecSLS(cfg, **model_kwargs)
     model.default_cfg = default_cfgs[variant]
+    model.feature_info = feature_info
     if pretrained:
         load_pretrained(
             model,
             num_classes=model_kwargs.get('num_classes', 0),
             in_chans=model_kwargs.get('in_chans', 3),
-            strict=True)
+            strict=load_strict)
+
+    if features:
+        model = FeatureNet(model, out_indices, flatten_sequential=True)
     return model
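A sketch of the feature-extraction path these SelecSLS changes enable. This is hedged: the exact output format of `FeatureNet` is not shown in this diff, and it is assumed here to return one tensor per selected feature level; 'selecsls60b' is one of the variants configured above.

import torch
from timm.models import create_model

# features_only / out_indices are popped inside _create_model above and used to
# wrap the backbone in FeatureNet
fnet = create_model('selecsls60b', features_only=True, out_indices=(0, 1, 2, 3, 4), pretrained=False)
feats = fnet(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one map per level, roughly following the feature_info reductions above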
@@ -12,11 +12,11 @@ import math

 from torch import nn as nn

-from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
 from .layers import SelectiveKernelConv, ConvBnAct, create_attn
-from .resnet import ResNet
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
+from .resnet import _create_resnet_with_cfg


 def _cfg(url='', **kwargs):
@@ -138,101 +138,80 @@ class SelectiveKernelBottleneck(nn.Module):
         return x


+def _create_skresnet(variant, pretrained=False, **kwargs):
+    default_cfg = default_cfgs[variant]
+    return _create_resnet_with_cfg(variant, default_cfg, pretrained=pretrained, **kwargs)
+
+
 @register_model
-def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def skresnet18(pretrained=False, **kwargs):
     """Constructs a Selective Kernel ResNet-18 model.

     Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
     variation splits the input channels to the selective convolutions to keep param count down.
     """
-    default_cfg = default_cfgs['skresnet18']
     sk_kwargs = dict(
         min_attn_channels=16,
         attn_reduction=8,
-        split_input=True
-    )
-    model = ResNet(
-        SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+        split_input=True)
+    model_args = dict(
+        block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
+        zero_init_last_bn=False, **kwargs)
+    return _create_skresnet('skresnet18', pretrained, **model_args)


 @register_model
-def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def skresnet34(pretrained=False, **kwargs):
     """Constructs a Selective Kernel ResNet-34 model.

     Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
     variation splits the input channels to the selective convolutions to keep param count down.
     """
-    default_cfg = default_cfgs['skresnet34']
     sk_kwargs = dict(
         min_attn_channels=16,
         attn_reduction=8,
-        split_input=True
-    )
-    model = ResNet(
-        SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+        split_input=True)
+    model_args = dict(
+        block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
+        zero_init_last_bn=False, **kwargs)
+    return _create_skresnet('skresnet34', pretrained, **model_args)


 @register_model
-def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def skresnet50(pretrained=False, **kwargs):
     """Constructs a Select Kernel ResNet-50 model.

     Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
     variation splits the input channels to the selective convolutions to keep param count down.
     """
-    sk_kwargs = dict(
-        split_input=True,
-    )
-    default_cfg = default_cfgs['skresnet50']
-    model = ResNet(
-        SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    sk_kwargs = dict(split_input=True)
+    model_args = dict(
+        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
+        zero_init_last_bn=False, **kwargs)
+    return _create_skresnet('skresnet50', pretrained, **model_args)


 @register_model
-def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def skresnet50d(pretrained=False, **kwargs):
     """Constructs a Select Kernel ResNet-50-D model.

     Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
     variation splits the input channels to the selective convolutions to keep param count down.
     """
-    sk_kwargs = dict(
-        split_input=True,
-    )
-    default_cfg = default_cfgs['skresnet50d']
-    model = ResNet(
-        SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
-        zero_init_last_bn=False, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    sk_kwargs = dict(split_input=True)
+    model_args = dict(
+        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
+        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
+    return _create_skresnet('skresnet50d', pretrained, **model_args)


 @register_model
-def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+def skresnext50_32x4d(pretrained=False, **kwargs):
     """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
     the SKNet-50 model in the Select Kernel Paper
     """
-    default_cfg = default_cfgs['skresnext50_32x4d']
-    model = ResNet(
-        SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
-        num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
+    model_args = dict(
+        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+        zero_init_last_bn=False, **kwargs)
+    return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .registry import register_model
 from .helpers import load_pretrained
+from .features import FeatureNet
 from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, SelectAdaptivePool2d, \
     create_attn, create_norm_act, get_norm_act_layer

@@ -296,6 +297,9 @@ class VovNet(nn.Module):
             conv_type(stem_chs[0], stem_chs[1], 3, stride=1, norm_layer=norm_layer),
             conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, norm_layer=norm_layer),
         ])
+        self.feature_info = [dict(
+            num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
+        current_stride = stem_stride

         # OSA stages
         in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
@@ -309,6 +313,9 @@ class VovNet(nn.Module):
                 downsample=downsample, **stage_args)
             ]
             self.num_features = stage_out_chs[i]
+            current_stride *= 2 if downsample else 1
+            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]

         self.stages = nn.Sequential(*stages)

         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@@ -338,24 +345,24 @@ class VovNet(nn.Module):


 def _vovnet(variant, pretrained=False, **kwargs):
-    load_strict = True
-    model_class = VovNet
+    features = False
+    out_indices = None
     if kwargs.pop('features_only', False):
-        assert False, 'Not Implemented'  # TODO
-        load_strict = False
+        features = True
         kwargs.pop('num_classes', 0)
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
     model_cfg = model_cfgs[variant]
-    default_cfg = default_cfgs[variant]
-    model = model_class(model_cfg, **kwargs)
-    model.default_cfg = default_cfg
+    model = VovNet(model_cfg, **kwargs)
+    model.default_cfg = default_cfgs[variant]
     if pretrained:
         load_pretrained(
-            model, default_cfg,
-            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
+            model,
+            num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
+    if features:
+        model = FeatureNet(model, out_indices, flatten_sequential=True)
     return model


 @register_model
 def vovnet39a(pretrained=False, **kwargs):
     return _vovnet('vovnet39a', pretrained=pretrained, **kwargs)
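The `feature_info` bookkeeping added above records, per stage, the channel count, the cumulative downsampling factor, and the module it can be hooked from. A small stand-alone sketch of the reduction arithmetic, using illustrative values rather than the real VovNet configs:

# mirrors: current_stride *= 2 if downsample else 1
stem_stride = 4
stage_out_chs = [256, 512, 768, 1024]    # hypothetical per-stage widths
downsamples = [False, True, True, True]  # first stage keeps stride, later ones halve resolution

feature_info = [dict(num_chs=128, reduction=2, module='stem.2')]  # illustrative stem entry
current_stride = stem_stride
for i, (chs, down) in enumerate(zip(stage_out_chs, downsamples)):
    current_stride *= 2 if down else 1
    feature_info.append(dict(num_chs=chs, reduction=current_stride, module=f'stages.{i}'))

print([f['reduction'] for f in feature_info])  # [2, 4, 8, 16, 32]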
@@ -26,6 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import load_pretrained
+from .features import FeatureNet
 from .layers import SelectAdaptivePool2d
 from .registry import register_model

@@ -49,12 +50,12 @@ default_cfgs = {


 class SeparableConv2d(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
         super(SeparableConv2d, self).__init__()

         self.conv1 = nn.Conv2d(
-            in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
-        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
+            in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False)

     def forward(self, x):
         x = self.conv1(x)
@@ -63,34 +64,26 @@ class SeparableConv2d(nn.Module):


 class Block(nn.Module):
-    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
+    def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True):
         super(Block, self).__init__()

-        if out_filters != in_filters or strides != 1:
-            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
-            self.skipbn = nn.BatchNorm2d(out_filters)
+        if out_channels != in_channels or strides != 1:
+            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
+            self.skipbn = nn.BatchNorm2d(out_channels)
         else:
             self.skip = None

-        self.relu = nn.ReLU(inplace=True)
         rep = []
-        filters = in_filters
-        if grow_first:
-            rep.append(self.relu)
-            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
-            rep.append(nn.BatchNorm2d(out_filters))
-            filters = out_filters
-
-        for i in range(reps - 1):
-            rep.append(self.relu)
-            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
-            rep.append(nn.BatchNorm2d(filters))
-
-        if not grow_first:
-            rep.append(self.relu)
-            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
-            rep.append(nn.BatchNorm2d(out_filters))
+        for i in range(reps):
+            if grow_first:
+                inc = in_channels if i == 0 else out_channels
+                outc = out_channels
+            else:
+                inc = in_channels
+                outc = in_channels if i < (reps - 1) else out_channels
+            rep.append(nn.ReLU(inplace=True))
+            rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
+            rep.append(nn.BatchNorm2d(outc))

         if not start_with_relu:
             rep = rep[1:]
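The rewritten loop collapses the old grow_first branches into a single pass; the per-repeat channel schedule it produces can be checked in isolation. A quick sketch (plain Python, no torch required) that reuses the same inc/outc rules and the real block sizes from the Xception constructor below:

def channel_schedule(in_channels, out_channels, reps, grow_first=True):
    # same inc/outc rules as the new Block.__init__ loop above
    sched = []
    for i in range(reps):
        if grow_first:
            inc = in_channels if i == 0 else out_channels
            outc = out_channels
        else:
            inc = in_channels
            outc = in_channels if i < (reps - 1) else out_channels
        sched.append((inc, outc))
    return sched

print(channel_schedule(64, 128, 2, grow_first=True))     # [(64, 128), (128, 128)]   e.g. block1
print(channel_schedule(728, 1024, 2, grow_first=False))  # [(728, 728), (728, 1024)] e.g. block12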
@@ -133,34 +126,35 @@ class Xception(nn.Module):

         self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
         self.bn1 = nn.BatchNorm2d(32)
-        self.relu = nn.ReLU(inplace=True)
+        self.act1 = nn.ReLU(inplace=True)

         self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
         self.bn2 = nn.BatchNorm2d(64)
-        # do relu here
+        self.act2 = nn.ReLU(inplace=True)

-        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
-        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
-        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
+        self.block1 = Block(64, 128, 2, 2, start_with_relu=False)
+        self.block2 = Block(128, 256, 2, 2)
+        self.block3 = Block(256, 728, 2, 2)

-        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block4 = Block(728, 728, 3, 1)
+        self.block5 = Block(728, 728, 3, 1)
+        self.block6 = Block(728, 728, 3, 1)
+        self.block7 = Block(728, 728, 3, 1)

-        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
-        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
+        self.block8 = Block(728, 728, 3, 1)
+        self.block9 = Block(728, 728, 3, 1)
+        self.block10 = Block(728, 728, 3, 1)
+        self.block11 = Block(728, 728, 3, 1)

-        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
+        self.block12 = Block(728, 1024, 2, 2, grow_first=False)

         self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
         self.bn3 = nn.BatchNorm2d(1536)
+        self.act3 = nn.ReLU(inplace=True)

-        # do relu here
         self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
         self.bn4 = nn.BatchNorm2d(self.num_features)
+        self.act4 = nn.ReLU(inplace=True)

         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@@ -188,11 +182,11 @@ class Xception(nn.Module):
     def forward_features(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.relu(x)
+        x = self.act1(x)

         x = self.conv2(x)
         x = self.bn2(x)
-        x = self.relu(x)
+        x = self.act2(x)

         x = self.block1(x)
         x = self.block2(x)
@@ -209,11 +203,11 @@ class Xception(nn.Module):

         x = self.conv3(x)
         x = self.bn3(x)
-        x = self.relu(x)
+        x = self.act3(x)

         x = self.conv4(x)
         x = self.bn4(x)
-        x = self.relu(x)
+        x = self.act4(x)
         return x

     def forward(self, x):
@@ -225,12 +219,28 @@ class Xception(nn.Module):
         return x


-@register_model
-def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    default_cfg = default_cfgs['xception']
-    model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
+def _xception(variant, pretrained=False, **kwargs):
+    load_strict = True
+    features = False
+    out_indices = None
+    if kwargs.pop('features_only', False):
+        load_strict = False
+        features = True
+        kwargs.pop('num_classes', 0)
+        out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
+    model = Xception(**kwargs)
+    model.default_cfg = default_cfgs[variant]
     if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
+        load_pretrained(
+            model,
+            num_classes=kwargs.get('num_classes', 0),
+            in_chans=kwargs.get('in_chans', 3),
+            strict=load_strict)
+    if features:
+        model = FeatureNet(model, out_indices)
     return model
+
+
+@register_model
+def xception(pretrained=False, **kwargs):
+    return _xception('xception', pretrained=pretrained, **kwargs)
validate.py
@@ -24,9 +24,8 @@ try:
 except ImportError:
     has_apex = False

-from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models,\
-    set_scriptable, set_no_jit
-from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
+from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging

 torch.backends.cudnn.benchmark = True
@@ -76,8 +75,25 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
+                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
+                    help='Real labels JSON file for imagenet evaluation')
+
+
+def set_jit_legacy():
+    """ Set JIT executor to legacy w/ support for op fusion
+    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT exectutor. These API are not supported so could change.
+    """
+    #
+    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_override_can_fuse_on_gpu(True)
+    #torch._C._jit_set_texpr_fuser_enabled(True)


 def validate(args):
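A compact sketch of how these pieces fit together, mirroring the `validate()` changes further down (the stand-in module and the hard-coded flag value are illustrative; `set_jit_legacy` is the helper defined just above):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # stand-in for the created model

legacy_jit = True  # i.e. the new --legacy-jit flag
if legacy_jit:
    set_jit_legacy()  # fall back to the pre-profiling JIT executor for better op fusion
torch.jit.optimized_execution(True)
scripted = torch.jit.script(model)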
@ -103,6 +119,8 @@ def validate(args):
|
|||||||
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
||||||
|
|
||||||
if args.torchscript:
|
if args.torchscript:
|
||||||
|
if args.legacy_jit:
|
||||||
|
set_jit_legacy()
|
||||||
torch.jit.optimized_execution(True)
|
torch.jit.optimized_execution(True)
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
@@ -116,13 +134,16 @@ def validate(args):

     criterion = nn.CrossEntropyLoss().cuda()

-    #from torchvision.datasets import ImageNet
-    #dataset = ImageNet(args.data, split='val')
     if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
         dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
     else:
         dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)

+    if args.real_labels:
+        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
+    else:
+        real_labels = None
+
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
         dataset,
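Taken together, the real-labels additions scattered across the hunks below amount to a small evaluation loop: build RealLabelsImagenet from the dataset's filenames, feed it each batch of logits, then ask it for top-1/top-5 accuracy. The following is a condensed, hedged sketch of that flow (not part of the commit): the model name, data path, real.json location, batch size, and the use_prefetcher=False choice are placeholders or assumptions for a CPU-only illustration.

import torch
from timm.models import create_model
from timm.data import Dataset, create_loader, resolve_data_config, RealLabelsImagenet

# Placeholder model and paths; resolve_data_config pulls sizes/normalization from the model's default_cfg.
model = create_model('resnet50', pretrained=True).eval()
data_config = resolve_data_config({}, model=model)
dataset = Dataset('/path/to/imagenet/val')
loader = create_loader(
    dataset,
    input_size=data_config['input_size'],
    batch_size=64,
    use_prefetcher=False,  # keep tensors on CPU for this sketch
    interpolation=data_config['interpolation'],
    mean=data_config['mean'],
    std=data_config['std'],
    crop_pct=data_config['crop_pct'])
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json='real.json')

with torch.no_grad():
    for input, _ in loader:
        # record whether any of the top-k predictions hit the ReaL label set for each image
        real_labels.add_result(model(input))

print(real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5))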
@@ -148,7 +169,7 @@ def validate(args):
         input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
         model(input)
         end = time.time()
-        for i, (input, target) in enumerate(loader):
+        for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
                 target = target.cuda()
                 input = input.cuda()
@@ -159,6 +180,9 @@ def validate(args):
             output = model(input)
             loss = criterion(output, target)

+            if real_labels is not None:
+                real_labels.add_result(output)
+
             # measure accuracy and record loss
             acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
             losses.update(loss.item(), input.size(0))
@@ -169,25 +193,35 @@ def validate(args):
             batch_time.update(time.time() - end)
             end = time.time()

-            if i % args.log_freq == 0:
+            if batch_idx % args.log_freq == 0:
                 logging.info(
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                     'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                     'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
-                        i, len(loader), batch_time=batch_time,
+                        batch_idx, len(loader), batch_time=batch_time,
                         rate_avg=input.size(0) / batch_time.avg,
                         loss=losses, top1=top1, top5=top5))

-    results = OrderedDict(
-        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
-        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
+    if real_labels is not None:
+        real_top1 = real_labels.get_accuracy(k=1)
+        real_top5 = real_labels.get_accuracy(k=5)
+        results = OrderedDict(
+            top1=round(real_top1, 4), top1_err=round(100 - real_top1, 4),
+            top5=round(real_top5, 4), top5_err=round(100 - real_top5, 4),
+            top1_original=round(top1.avg, 4),
+            top5_original=round(top5.avg, 4))
+    else:
+        results = OrderedDict(
+            top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
+            top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4))
+    results.update(OrderedDict(
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         cropt_pct=crop_pct,
-        interpolation=data_config['interpolation'])
+        interpolation=data_config['interpolation']
+    ))

     logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
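Net effect on the summary: when --real-labels is supplied, the top1/top5 fields (and the derived *_err fields, whose 100 - value form suggests get_accuracy returns a percentage) now report ReaL-label accuracy, while the original-label scores are preserved as top1_original/top5_original; without the flag the results dictionary keeps its previous shape. In both cases the shared metadata (param_count, img_size, cropt_pct, interpolation) is merged in afterwards via results.update().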