From 53a48883280743f5c878897bdfd3d471f1c191e4 Mon Sep 17 00:00:00 2001
From: Li zhuoqun
Date: Fri, 19 Jan 2024 23:19:36 +0800
Subject: [PATCH] Add droppath and type hint to Xception.

---
 timm/models/xception_aligned.py | 105 +++++++++++++++++++-------------
 1 file changed, 61 insertions(+), 44 deletions(-)

diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
index 13adae68..e4b28425 100644
--- a/timm/models/xception_aligned.py
+++ b/timm/models/xception_aligned.py
@@ -6,12 +6,13 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import List, Dict, Type, Optional
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
 from timm.layers.helpers import to_3tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -23,14 +24,14 @@ __all__ = ['XceptionAligned']
 class SeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
     ):
         super(SeparableConv2d, self).__init__()
         self.kernel_size = kernel_size
@@ -61,15 +62,15 @@ class SeparableConv2d(nn.Module):
 class PreSeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            first_act=True,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            first_act: bool = True,
     ):
         super(PreSeparableConv2d, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -95,15 +96,16 @@ class PreSeparableConv2d(nn.Module):
 class XceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            start_with_relu=True,
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            start_with_relu: bool = True,
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(XceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -126,12 +128,16 @@ class XceptionModule(nn.Module):
                 act_layer=separable_act_layer, norm_layer=norm_layer))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         skip = x
         x = self.stack(x)
         if self.shortcut is not None:
             skip = self.shortcut(skip)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + skip
         return x
 
@@ -139,14 +145,15 @@ class XceptionModule(nn.Module):
 class PreXceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(PreXceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -174,11 +181,15 @@ class PreXceptionModule(nn.Module):
             ))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         x = self.norm(x)
         skip = x
         x = self.stack(x)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + self.shortcut(skip)
         return x
 
@@ -189,15 +200,16 @@ class XceptionAligned(nn.Module):
 
     def __init__(
             self,
-            block_cfg,
-            num_classes=1000,
-            in_chans=3,
-            output_stride=32,
-            preact=False,
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            drop_rate=0.,
-            global_pool='avg',
+            block_cfg: List[Dict],
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            output_stride: int = 32,
+            preact: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
         super(XceptionAligned, self).__init__()
         assert output_stride in (8, 16, 32)
@@ -217,7 +229,11 @@ class XceptionAligned(nn.Module):
         self.feature_info = []
         self.blocks = nn.Sequential()
         module_fn = PreXceptionModule if preact else XceptionModule
+        net_num_blocks = len(block_cfg)
+        net_block_idx = 0
         for i, b in enumerate(block_cfg):
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
             b['dilation'] = curr_dilation
             if b['stride'] > 1:
                 name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
@@ -230,6 +246,7 @@ class XceptionAligned(nn.Module):
                 curr_stride = next_stride
             self.blocks.add_module(str(i), module_fn(**b, **layer_args))
             self.num_features = self.blocks[-1].out_channels
+            net_block_idx += 1
         self.feature_info += [dict(
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
 
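
Reviewer note (not part of the patch): a minimal sketch of how the new drop_path_rate argument could be exercised once this change is applied. The model name 'xception65', the rate 0.1, and the input size are illustrative assumptions, not taken from the diff.

    # Illustrative sketch only -- assumes this patch is applied to timm.
    # With drop_path_rate=0.1 and N entries in block_cfg, the linear decay rule
    # gives block_dpr = 0.1 * idx / (N - 1): 0.0 for the first block, up to 0.1
    # for the last; blocks with block_dpr == 0. get drop_path=None (identity).
    import torch
    import timm

    model = timm.create_model('xception65', pretrained=False, drop_path_rate=0.1)
    model.train()  # DropPath only drops residual branches in training mode
    out = model(torch.randn(2, 3, 299, 299))
    print(out.shape)  # torch.Size([2, 1000])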