Add DropPath and type hints to Xception.

Li zhuoqun 2024-01-19 23:19:36 +08:00 committed by Ross Wightman
parent 7f19a4cce7
commit 53a4888328

timm/models/xception_aligned.py

@@ -6,12 +6,13 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import List, Dict, Type, Optional
 import torch
 import torch.nn as nn
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
 from timm.layers.helpers import to_3tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -23,14 +24,14 @@ __all__ = ['XceptionAligned']
 class SeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
     ):
         super(SeparableConv2d, self).__init__()
         self.kernel_size = kernel_size
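
The new PadType alias imported above covers the padding specs that create_conv2d accepts. As a rough sketch of what the alias looks like (an assumption for illustration; see timm/layers/typing.py for the authoritative definition):

from typing import Tuple, Union

# Assumed shape of timm's PadType alias: a string like '' / 'same' / 'valid',
# an explicit int, or a per-dimension tuple.
PadType = Union[str, int, Tuple[int, int]]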
@@ -61,15 +62,15 @@ class SeparableConv2d(nn.Module):
 class PreSeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            first_act=True,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            first_act: bool = True,
     ):
         super(PreSeparableConv2d, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -95,15 +96,16 @@ class PreSeparableConv2d(nn.Module):
 class XceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            start_with_relu=True,
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            start_with_relu: bool = True,
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(XceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -126,12 +128,16 @@ class XceptionModule(nn.Module):
                 act_layer=separable_act_layer, norm_layer=norm_layer))
             in_chs = out_chs[i]

+        self.drop_path = drop_path
+
     def forward(self, x):
         skip = x
         x = self.stack(x)
         if self.shortcut is not None:
             skip = self.shortcut(skip)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + skip
         return x
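
The guarded add above is the standard stochastic-depth pattern: while training, the residual branch is randomly zeroed per sample before the skip connection, so the skip path always survives. For reference, a minimal sketch of the mechanism that timm's DropPath layer implements (DropPathSketch is a hypothetical name for illustration, not the library code):

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Stochastic depth: drop a sample's whole residual branch with
    probability drop_prob during training, rescale the survivors."""
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x  # no-op at inference or when disabled
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast across C, H, W
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob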
@@ -139,14 +145,15 @@ class XceptionModule(nn.Module):
 class PreXceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(PreXceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -174,11 +181,15 @@ class PreXceptionModule(nn.Module):
             ))
             in_chs = out_chs[i]

+        self.drop_path = drop_path
+
     def forward(self, x):
         x = self.norm(x)
         skip = x
         x = self.stack(x)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + self.shortcut(skip)
         return x
@@ -189,15 +200,16 @@ class XceptionAligned(nn.Module):
     def __init__(
             self,
-            block_cfg,
-            num_classes=1000,
-            in_chans=3,
-            output_stride=32,
-            preact=False,
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            drop_rate=0.,
-            global_pool='avg',
+            block_cfg: List[Dict],
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            output_stride: int = 32,
+            preact: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
         super(XceptionAligned, self).__init__()
         assert output_stride in (8, 16, 32)
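
With drop_path_rate in the constructor signature, it threads through timm's factory kwargs like any other model argument. A usage sketch, assuming a timm build with this patch applied:

import timm

# drop_path_rate is forwarded to XceptionAligned.__init__ via **kwargs
model = timm.create_model('xception41', pretrained=False, drop_path_rate=0.1)
model.train()  # DropPath only takes effect in training mode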
@@ -217,7 +229,11 @@ class XceptionAligned(nn.Module):
         self.feature_info = []
         self.blocks = nn.Sequential()
         module_fn = PreXceptionModule if preact else XceptionModule
+        net_num_blocks = len(block_cfg)
+        net_block_idx = 0
         for i, b in enumerate(block_cfg):
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
             b['dilation'] = curr_dilation
             if b['stride'] > 1:
                 name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
@@ -230,6 +246,7 @@ class XceptionAligned(nn.Module):
                 curr_stride = next_stride
             self.blocks.add_module(str(i), module_fn(**b, **layer_args))
             self.num_features = self.blocks[-1].out_channels
+            net_block_idx += 1
         self.feature_info += [dict(
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
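
The decay rule gives block i of N a drop rate of drop_path_rate * i / (N - 1), so the first block is never dropped and the last block sees the full rate. A hypothetical helper (drop_path_schedule is not part of timm) that reproduces the schedule, assuming at least two blocks so the division is well defined:

from typing import List

def drop_path_schedule(drop_path_rate: float, num_blocks: int) -> List[float]:
    # Linear decay: block 0 gets 0.0, the final block gets drop_path_rate.
    return [drop_path_rate * i / (num_blocks - 1) for i in range(num_blocks)]

print(drop_path_schedule(0.1, 5))  # ≈ [0.0, 0.025, 0.05, 0.075, 0.1]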