""" TResNet: High Performance GPU-Dedicated Architecture https://arxiv.org/pdf/2003.13630.pdf Original model: https://github.com/mrT23/TResNet """ from collections import OrderedDict from functools import partial from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations __all__ = ['TResNet'] # model_registry will add each entrypoint fn to this class BasicBlock(nn.Module): expansion = 1 def __init__( self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None, drop_path_rate=0. ): super(BasicBlock, self).__init__() self.downsample = downsample self.stride = stride act_layer = partial(nn.LeakyReLU, negative_slope=1e-3) self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer) self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False) self.act = nn.ReLU(inplace=True) rd_chs = max(planes * self.expansion // 4, 64) self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() def forward(self, x): if self.downsample is not None: shortcut = self.downsample(x) else: shortcut = x out = self.conv1(x) out = self.conv2(out) if self.se is not None: out = self.se(out) out = self.drop_path(out) + shortcut out = self.act(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__( self, inplanes, planes, stride=1, downsample=None, use_se=True, act_layer=None, aa_layer=None, drop_path_rate=0., ): super(Bottleneck, self).__init__() self.downsample = downsample self.stride = stride act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3) self.conv1 = ConvNormAct( inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer) self.conv2 = ConvNormAct( planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer) reduction_chs = max(planes * self.expansion // 8, 64) self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None self.conv3 = ConvNormAct( planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act = nn.ReLU(inplace=True) def forward(self, x): if self.downsample is not None: shortcut = self.downsample(x) else: shortcut = x out = self.conv1(x) out = self.conv2(out) if self.se is not None: out = self.se(out) out = self.conv3(out) out = self.drop_path(out) + shortcut out = self.act(out) return out class TResNet(nn.Module): def __init__( self, layers, in_chans=3, num_classes=1000, width_factor=1.0, v2=False, global_pool='fast', drop_rate=0., drop_path_rate=0., ): self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False super(TResNet, self).__init__() aa_layer = BlurPool2d act_layer = nn.LeakyReLU # TResnet stages self.inplanes = int(64 * width_factor) self.planes = int(64 * width_factor) if v2: self.inplanes = self.inplanes // 8 * 8 self.planes = self.planes // 8 * 8 dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer) layer1 = self._make_layer( 
class TResNet(nn.Module):
    def __init__(
            self,
            layers,
            in_chans=3,
            num_classes=1000,
            width_factor=1.0,
            v2=False,
            global_pool='fast',
            drop_rate=0.,
            drop_path_rate=0.,
    ):
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        super(TResNet, self).__init__()

        aa_layer = BlurPool2d
        act_layer = nn.LeakyReLU

        # TResnet stages
        self.inplanes = int(64 * width_factor)
        self.planes = int(64 * width_factor)
        if v2:
            self.inplanes = self.inplanes // 8 * 8
            self.planes = self.planes // 8 * 8

        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]

        conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer)
        layer1 = self._make_layer(
            Bottleneck if v2 else BasicBlock,
            self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0])
        layer2 = self._make_layer(
            Bottleneck if v2 else BasicBlock,
            self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1])
        layer3 = self._make_layer(
            Bottleneck,
            self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2])
        layer4 = self._make_layer(
            Bottleneck,
            self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3])

        # body
        self.body = nn.Sequential(OrderedDict([
            ('s2d', SpaceToDepth()),
            ('conv1', conv1),
            ('layer1', layer1),
            ('layer2', layer2),
            ('layer3', layer3),
            ('layer4', layer4),
        ]))

        self.feature_info = [
            dict(num_chs=self.planes, reduction=2, module=''),  # Not with S2D?
            dict(num_chs=self.planes * (Bottleneck.expansion if v2 else 1), reduction=4, module='body.layer1'),
            dict(num_chs=self.planes * 2 * (Bottleneck.expansion if v2 else 1), reduction=8, module='body.layer2'),
            dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
            dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
        ]

        # head
        self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        # model initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)

        # residual connections special initialization
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.zeros_(m.conv2.bn.weight)
            if isinstance(m, Bottleneck):
                nn.init.zeros_(m.conv3.bn.weight)

    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None, drop_path_rate=0.):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
            layers += [ConvNormAct(
                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
            downsample = nn.Sequential(*layers)

        layers = []
        for i in range(blocks):
            layers.append(block(
                self.inplanes,
                planes,
                stride=stride if i == 0 else 1,
                downsample=downsample if i == 0 else None,
                use_se=use_se,
                aa_layer=aa_layer,
                drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
            ))
            self.inplanes = planes * block.expansion
        return nn.Sequential(*layers)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            List of intermediate features if `intermediates_only` is True, otherwise a tuple of
            (final features, list of intermediate features).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        stage_ends = [1, 2, 3, 4, 5]
        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
        take_indices = [stage_ends[i] for i in take_indices]
        max_index = stage_ends[max_index]

        # forward pass
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.body
        else:
            stages = self.body[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        stage_ends = [1, 2, 3, 4, 5]
        take_indices, max_index = feature_take_indices(len(stage_ends), indices)
        max_index = stage_ends[max_index]
        self.body = self.body[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = self.body.s2d(x)
            x = self.body.conv1(x)
            x = checkpoint_seq([
                self.body.layer1, self.body.layer2, self.body.layer3, self.body.layer4], x, flatten=True)
        else:
            x = self.body(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict, model):
    if 'body.conv1.conv.weight' in state_dict:
        return state_dict

    import re
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(r'conv(\d+)\.0.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
        k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
        if k.endswith('bn.weight'):
            # convert weight from inplace_abn to batchnorm
            v = v.abs().add(1e-5)
        out_dict[k] = v
    return out_dict
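# Illustrative note (not part of the original file): the remapping above rewrites
# checkpoint keys from the upstream TResNet layout, which stores each conv + norm
# pair under numeric sub-indices, into this module's ConvNormAct naming. For a
# hypothetical legacy key, re.sub(r'conv(\d+)\.0.0', ...) turns '...conv1.0.0.weight'
# into '...conv1.conv.weight', and the companion '...conv1.0.1.weight' becomes
# '...conv1.bn.weight'. The final `v.abs().add(1e-5)` converts an InplaceABN scale
# (which may carry a sign) into a strictly positive BatchNorm weight.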
def _create_tresnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        TResNet,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': (0., 0., 0.),
        'std': (1., 1., 1.),
        'first_conv': 'body.conv1.conv',
        'classifier': 'head.fc',
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'tresnet_m.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_m.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
    'tresnet_m.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_l.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_m.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_l.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),

    'tresnet_v2_l.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_v2_l.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
})


@register_model
def tresnet_m(pretrained=False, **kwargs) -> TResNet:
    model_args = dict(layers=[3, 4, 11, 3])
    return _create_tresnet('tresnet_m', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def tresnet_l(pretrained=False, **kwargs) -> TResNet:
    model_args = dict(layers=[4, 5, 18, 3], width_factor=1.2)
    return _create_tresnet('tresnet_l', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def tresnet_xl(pretrained=False, **kwargs) -> TResNet:
    model_args = dict(layers=[4, 5, 24, 3], width_factor=1.3)
    return _create_tresnet('tresnet_xl', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def tresnet_v2_l(pretrained=False, **kwargs) -> TResNet:
    model_args = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True)
    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **dict(model_args, **kwargs))


register_model_deprecations(__name__, {
    'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
    'tresnet_m_448': 'tresnet_m.miil_in1k_448',
    'tresnet_l_448': 'tresnet_l.miil_in1k_448',
    'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
})
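
# Minimal usage sketch (not part of the original file): builds a randomly
# initialized tresnet_m with an arbitrary head size and pulls the last two
# feature stages via forward_intermediates; pretrained weights are skipped
# to avoid a download.
if __name__ == '__main__':
    model = tresnet_m(pretrained=False, num_classes=10)
    model.eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)  # classification logits, shape (1, 10)
        # indices=2 -> take the last 2 stages (layer3, layer4 outputs)
        feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
    print(logits.shape, [tuple(f.shape) for f in feats])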