diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 959b0a61..7c926319 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -46,6 +46,7 @@ from .nfnet import *
 from .pit import *
 from .pnasnet import *
 from .pvt_v2 import *
+from .rdnet import *
 from .regnet import *
 from .repghost import *
 from .repvit import *
diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py
new file mode 100644
index 00000000..cf10198a
--- /dev/null
+++ b/timm/models/rdnet.py
@@ -0,0 +1,505 @@
+"""
+RDNet
+Copyright (c) 2024-present NAVER Cloud Corp.
+Apache-2.0
+"""
+
+from functools import partial
+from typing import List, Optional, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, NormMlpClassifierHead, ClassifierHead, EffectiveSEModule, \
+    make_divisible, get_act_layer, get_norm_layer
+from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
+from ._manipulate import named_apply
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ["RDNet"]
+
+
+class Block(nn.Module):
+    def __init__(self, in_chs, inter_chs, out_chs, norm_layer, act_layer):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3),
+            norm_layer(in_chs),
+            nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0),
+            act_layer(),
+            nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class BlockESE(nn.Module):
+    def __init__(self, in_chs, inter_chs, out_chs, norm_layer, act_layer):
+        super().__init__()
+        self.layers = nn.Sequential(
+            nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3),
+            norm_layer(in_chs),
+            nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0),
+            act_layer(),
+            nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0),
+            EffectiveSEModule(out_chs),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class DenseBlock(nn.Module):
+    def __init__(
+            self,
+            num_input_features,
+            growth_rate,
+            bottleneck_width_ratio,
+            drop_path_rate,
+            drop_rate=0.0,
+            rand_gather_step_prob=0.0,
+            block_idx=0,
+            block_type="Block",
+            ls_init_value=1e-6,
+            norm_layer="layernorm2d",
+            act_layer="gelu",
+    ):
+        super().__init__()
+        self.drop_rate = drop_rate
+        self.drop_path_rate = drop_path_rate
+        self.rand_gather_step_prob = rand_gather_step_prob
+        self.block_idx = block_idx
+        self.growth_rate = growth_rate
+
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(growth_rate)) if ls_init_value > 0 else None
+        growth_rate = int(growth_rate)
+        inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
+
+        if self.drop_path_rate > 0:
+            self.drop_path = DropPath(drop_path_rate)
+
+        self.layers = eval(block_type)(
+            in_chs=num_input_features,
+            inter_chs=inter_chs,
+            out_chs=growth_rate,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+        )
+
+    def forward(self, x):
+        if isinstance(x, List):
+            x = torch.cat(x, 1)
+        x = self.layers(x)
+
+        if self.gamma is not None:
+            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
+
+        if self.drop_path_rate > 0 and self.training:
+            x = self.drop_path(x)
+        return x
+
+
+class DenseStage(nn.Sequential):
+    def __init__(self, num_block, num_input_features, drop_path_rates, growth_rate, **kwargs):
+        super().__init__()
+        for i in range(num_block):
+            layer = DenseBlock(
+                num_input_features=num_input_features,
+                growth_rate=growth_rate,
+                drop_path_rate=drop_path_rates[i],
+                block_idx=i,
+                **kwargs,
+            )
+            num_input_features += growth_rate
+            self.add_module(f"dense_block{i}", layer)
+        self.num_out_features = num_input_features
+
+    def forward(self, init_feature):
+        features = [init_feature]
+        for module in self:
+            new_feature = module(features)
+            features.append(new_feature)
+        return torch.cat(features, 1)
+
+
+class RDNet(nn.Module):
+    def __init__(
+            self,
+            in_chans: int = 3,  # timm option [--in-chans]
+            num_classes: int = 1000,  # timm option [--num-classes]
+            global_pool: str = 'avg',  # timm option [--gp]
+            growth_rates: Tuple[int, ...] = (64, 104, 128, 128, 128, 128, 224),
+            num_blocks_list: Tuple[int, ...] = (3, 3, 3, 3, 3, 3, 3),
+            block_type: Tuple[str, ...] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
+            is_downsample_block: Tuple[bool, ...] = (None, True, True, False, False, False, True),
+            bottleneck_width_ratio: int = 4,
+            transition_compression_ratio: float = 0.5,
+            ls_init_value: float = 1e-6,
+            stem_type: str = 'patch',
+            patch_size: int = 4,
+            num_init_features: int = 64,
+            head_init_scale: float = 1.,
+            head_norm_first: bool = False,
+            head_hidden_size: Optional[int] = None,
+            conv_bias: bool = True,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: str = "layernorm2d",
+            norm_eps: Optional[float] = None,
+            drop_rate: float = 0.0,  # timm option [--drop: dropout ratio]
+            drop_path_rate: float = 0.0,  # timm option [--drop-path: drop-path ratio]
+    ):
+        """
+        Args:
+            in_chans: Number of input image channels.
+            num_classes: Number of classes for classification head.
+            global_pool: Global pooling type.
+            growth_rates: Growth rate at each stage.
+            num_blocks_list: Number of blocks at each stage.
+            block_type: Block type at each stage ('Block' or 'BlockESE').
+            is_downsample_block: Whether to downsample at each stage.
+            bottleneck_width_ratio: Bottleneck width ratio (similar to mlp expansion ratio).
+            transition_compression_ratio: Channel compression ratio of transition layers.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            num_init_features: Number of features of stem.
+            head_init_scale: Init scaling value for classifier weights and biases.
+            head_norm_first: Apply normalization before global pool + head.
+            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
+            conv_bias: Use bias layers w/ all convolutions.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            norm_eps: Small value to avoid division by zero in normalization.
+            drop_rate: Head pre-classifier dropout rate.
+            drop_path_rate: Stochastic depth drop rate.
+        """
+        super().__init__()
+        assert len(growth_rates) == len(num_blocks_list) == len(is_downsample_block)
+        act_layer = get_act_layer(act_layer)
+        norm_layer = get_norm_layer(norm_layer)
+        if norm_eps is not None:
+            norm_layer = partial(norm_layer, eps=norm_eps)
+
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+
+        # stem
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
+        if stem_type == 'patch':
+            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
+            self.stem = nn.Sequential(
+                nn.Conv2d(in_chans, num_init_features, kernel_size=patch_size, stride=patch_size, bias=conv_bias),
+                norm_layer(num_init_features),
+            )
+            stem_stride = patch_size
+        else:
+            mid_chs = make_divisible(num_init_features // 2) if 'tiered' in stem_type else num_init_features
+            self.stem = nn.Sequential(
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(mid_chs, num_init_features, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                norm_layer(num_init_features),
+            )
+            stem_stride = 4
+
+        # features
+        self.feature_info = []
+        self.num_stages = len(growth_rates)
+        curr_stride = stem_stride
+        num_features = num_init_features
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(num_blocks_list)).split(num_blocks_list)]
+
+        dense_stages = []
+        for i in range(self.num_stages):
+            dense_stage_layers = []
+            if i != 0:
+                compressed_num_features = int(num_features * transition_compression_ratio / 8) * 8
+                k_size = stride = 1
+                if is_downsample_block[i]:
+                    curr_stride *= 2
+                    k_size = stride = 2
+
+                dense_stage_layers.append(norm_layer(num_features))
+                dense_stage_layers.append(
+                    nn.Conv2d(num_features, compressed_num_features, kernel_size=k_size, stride=stride, padding=0)
+                )
+                num_features = compressed_num_features
+
+            stage = DenseStage(
+                num_block=num_blocks_list[i],
+                num_input_features=num_features,
+                growth_rate=growth_rates[i],
+                bottleneck_width_ratio=bottleneck_width_ratio,
+                drop_rate=drop_rate,
+                drop_path_rates=dp_rates[i],
+                ls_init_value=ls_init_value,
+                block_type=block_type[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            )
+            dense_stage_layers.append(stage)
+            num_features += num_blocks_list[i] * growth_rates[i]
+
+            if i + 1 == self.num_stages or (i + 1 != self.num_stages and is_downsample_block[i + 1]):
+                self.feature_info += [
+                    dict(
+                        num_chs=num_features,
+                        reduction=curr_stride,
+                        module=f'dense_stages.{i}',
+                        growth_rate=growth_rates[i],
+                    )
+                ]
+            dense_stages.append(nn.Sequential(*dense_stage_layers))
+        self.dense_stages = nn.Sequential(*dense_stages)
+        self.num_features = num_features
+
+        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
+        # otherwise pool -> norm -> fc, the default RDNet ordering (pretrained NV weights)
+        if head_norm_first:
+            assert not head_hidden_size
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=head_hidden_size,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+            )
+        self.head_hidden_size = self.head.num_features
+
+        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
+
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            dense_stages = self.dense_stages
+        else:
+            dense_stages = self.dense_stages[:max_index]
+        for stage in dense_stages:
+            feat_idx += 1
+            x = stage(x)
+            if feat_idx in take_indices:
+                # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.dense_stages) + 1, indices)
+        self.dense_stages = self.dense_stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
+    @torch.jit.ignore
+    def get_classifier(self) -> nn.Module:
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.head.reset(num_classes, global_pool)
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.dense_stages(x)
+        x = self.norm_pre(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        assert not coarse, "coarse grouping is not implemented for RDNet"
+        return dict(
+            stem=r'^stem',
+            blocks=r'^dense_stages\.(\d+)',
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.dense_stages:
+            s.grad_checkpointing = enable
+
+
+def _init_weights(module, name=None, head_init_scale=1.0):
+    if isinstance(module, nn.Conv2d):
+        nn.init.kaiming_normal_(module.weight)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.constant_(module.weight, 1)
+        nn.init.constant_(module.bias, 0)
+    elif isinstance(module, nn.Linear):
+        nn.init.constant_(module.bias, 0)
+        if name and 'head.' in name:
+            module.weight.data.mul_(head_init_scale)
+            module.bias.data.mul_(head_init_scale)
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ Remap NV checkpoints -> timm """
+    if 'stem.0.weight' in state_dict:
+        return state_dict  # non-NV checkpoint
+    if 'model' in state_dict:
+        state_dict = state_dict['model']
+
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        k = k.replace('stem.stem.', 'stem.')
+        out_dict[k] = v
+
+    return out_dict
+
+
+def _create_rdnet(variant, pretrained=False, **kwargs):
+    model = build_model_with_cfg(
+        RDNet, variant, pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
+        **kwargs)
+    return model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        "url": url,
+        "num_classes": 1000, "input_size": (3, 224, 224), "pool_size": (7, 7),
+        "crop_pct": 0.9, "interpolation": "bicubic",
+        "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD,
+        "first_conv": "stem.0", "classifier": "head.fc",
+        "paper_ids": "arXiv:2403.19588",
+        "paper_name": "DenseNets Reloaded: Paradigm Shift Beyond ResNets and ViTs",
+        "origin_url": "https://github.com/naver-ai/rdnet",
+        **kwargs,
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'rdnet_tiny.nv_in1k': _cfg(
+        hf_hub_id='naver-ai/rdnet_tiny.nv_in1k'),
+    'rdnet_small.nv_in1k': _cfg(
+        hf_hub_id='naver-ai/rdnet_small.nv_in1k'),
+    'rdnet_base.nv_in1k': _cfg(
+        hf_hub_id='naver-ai/rdnet_base.nv_in1k'),
+    'rdnet_large.nv_in1k': _cfg(
+        hf_hub_id='naver-ai/rdnet_large.nv_in1k'),
+    'rdnet_large.nv_in1k_ft_in1k_384': _cfg(
+        hf_hub_id='naver-ai/rdnet_large.nv_in1k_ft_in1k_384',
+        input_size=(3, 384, 384), crop_pct=1.0),
+})
+
+
+@register_model
+def rdnet_tiny(pretrained=False, **kwargs):
+    n_layer = 7
+    model_args = {
+        "num_init_features": 64,
+        "growth_rates": [64] + [104] + [128] * 4 + [224],
+        "num_blocks_list": [3] * n_layer,
+        "is_downsample_block": (None, True, True, False, False, False, True),
+        "transition_compression_ratio": 0.5,
+        "block_type": ["Block"] + ["Block"] + ["BlockESE"] * 4 + ["BlockESE"],
+    }
+    model = _create_rdnet("rdnet_tiny", pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def rdnet_small(pretrained=False, **kwargs):
+    n_layer = 11
+    model_args = {
+        "num_init_features": 72,
+        "growth_rates": [64] + [128] + [128] * (n_layer - 4) + [240] * 2,
+        "num_blocks_list": [3] * n_layer,
+        "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
+        "transition_compression_ratio": 0.5,
+        "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
+    }
+    model = _create_rdnet("rdnet_small", pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def rdnet_base(pretrained=False, **kwargs):
+    n_layer = 11
+    model_args = {
+        "num_init_features": 120,
+        "growth_rates": [96] + [128] + [168] * (n_layer - 4) + [336] * 2,
+        "num_blocks_list": [3] * n_layer,
+        "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
+        "transition_compression_ratio": 0.5,
+        "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
+    }
+    model = _create_rdnet("rdnet_base", pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def rdnet_large(pretrained=False, **kwargs):
+    n_layer = 12
+    model_args = {
+        "num_init_features": 144,
+        "growth_rates": [128] + [192] + [256] * (n_layer - 4) + [360] * 2,
+        "num_blocks_list": [3] * n_layer,
+        "is_downsample_block": (None, True, True, False, False, False, False, False, False, False, True, False),
+        "transition_compression_ratio": 0.5,
+        "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
+    }
+    model = _create_rdnet("rdnet_large", pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
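
Usage note (not part of the diff): the new variants register through the standard `timm` factory, so a quick smoke test only needs `timm.create_model()`; with `pretrained=True`, the `.nv_in1k` weights resolve via the `hf_hub_id` entries in `default_cfgs`. A minimal sketch, assuming this branch of timm is installed:

```python
import torch
import timm

# Instantiate the smallest variant; pretrained=True would instead pull
# 'naver-ai/rdnet_tiny.nv_in1k' from the HF Hub per default_cfgs above.
model = timm.create_model('rdnet_tiny', pretrained=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```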
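Feature-extraction sketch: `_create_rdnet()` declares `feature_cfg=dict(out_indices=(0, 1, 2, 3), ...)`, so wrapping with `features_only=True` should yield four maps whose channels and strides come from the `feature_info` built in `RDNet.__init__`; `forward_intermediates()` can also be called directly, where the stem is feature index 0. Again assuming this branch is installed:

```python
import torch
import timm

# features_only wraps the model per feature_cfg; channels()/reduction()
# read back the feature_info entries recorded at construction time.
backbone = timm.create_model('rdnet_tiny', features_only=True)
print(backbone.feature_info.channels(), backbone.feature_info.reduction())

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # four NCHW maps, per out_indices=(0, 1, 2, 3)

# Direct intermediate access; stem is index 0, dense stages are 1..N.
model = timm.create_model('rdnet_tiny')
final, inters = model.forward_intermediates(torch.randn(1, 3, 224, 224), indices=(1, 2))
```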