Add DropPath and type hints to Xception.

Li zhuoqun 2024-01-19 23:19:36 +08:00 committed by Ross Wightman
parent 7f19a4cce7
commit 53a4888328


@@ -6,12 +6,13 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import List, Dict, Type, Optional
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
 from timm.layers.helpers import to_3tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -23,14 +24,14 @@ __all__ = ['XceptionAligned']
 class SeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
     ):
         super(SeparableConv2d, self).__init__()
         self.kernel_size = kernel_size
@@ -61,15 +62,15 @@ class SeparableConv2d(nn.Module):
 class PreSeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            first_act=True,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            first_act: bool = True,
     ):
         super(PreSeparableConv2d, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -95,15 +96,16 @@ class PreSeparableConv2d(nn.Module):
 class XceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            start_with_relu=True,
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            start_with_relu: bool = True,
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(XceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -126,12 +128,16 @@ class XceptionModule(nn.Module):
                 act_layer=separable_act_layer, norm_layer=norm_layer))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         skip = x
         x = self.stack(x)
         if self.shortcut is not None:
             skip = self.shortcut(skip)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + skip
         return x
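For context on what the new drop_path hook does: DropPath (stochastic depth) randomly zeroes the residual branch for a subset of samples during training and rescales the survivors, so a block is occasionally skipped end to end. The module wired in here is timm.layers.DropPath; the snippet below is only a minimal illustrative sketch of the idea, not the timm implementation.

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    # Illustrative stochastic depth: drop the whole residual branch per sample.
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast across C, H, W
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob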
@@ -139,14 +145,15 @@ class XceptionModule(nn.Module):
 class PreXceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(PreXceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -174,11 +181,15 @@ class PreXceptionModule(nn.Module):
             ))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         x = self.norm(x)
         skip = x
         x = self.stack(x)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + self.shortcut(skip)
         return x
@@ -189,15 +200,16 @@ class XceptionAligned(nn.Module):
 
     def __init__(
             self,
-            block_cfg,
-            num_classes=1000,
-            in_chans=3,
-            output_stride=32,
-            preact=False,
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            drop_rate=0.,
-            global_pool='avg',
+            block_cfg: List[Dict],
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            output_stride: int = 32,
+            preact: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
         super(XceptionAligned, self).__init__()
         assert output_stride in (8, 16, 32)
@@ -217,7 +229,11 @@ class XceptionAligned(nn.Module):
         self.feature_info = []
         self.blocks = nn.Sequential()
         module_fn = PreXceptionModule if preact else XceptionModule
+        net_num_blocks = len(block_cfg)
+        net_block_idx = 0
         for i, b in enumerate(block_cfg):
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
             b['dilation'] = curr_dilation
             if b['stride'] > 1:
                 name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
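The block_dpr line above implements a linear stochastic depth schedule: the first module gets drop probability 0 and the last gets the full drop_path_rate. A quick sketch of the values it produces, assuming 12 blocks and drop_path_rate=0.1 (illustrative numbers, not tied to a specific model config):

drop_path_rate = 0.1   # illustrative
net_num_blocks = 12    # stands in for len(block_cfg)
dprs = [drop_path_rate * i / (net_num_blocks - 1) for i in range(net_num_blocks)]
# first block 0.0, last block 0.1, evenly spaced in between:
# [0.0, 0.0091, 0.0182, ..., 0.0909, 0.1]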
@@ -230,6 +246,7 @@ class XceptionAligned(nn.Module):
                     curr_stride = next_stride
             self.blocks.add_module(str(i), module_fn(**b, **layer_args))
             self.num_features = self.blocks[-1].out_channels
+            net_block_idx += 1
 
         self.feature_info += [dict(
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
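With drop_path_rate now a constructor argument, stochastic depth can be enabled when building these models. A minimal usage sketch, assuming the registered xception41 entry point and the usual timm.create_model kwarg pass-through (values are illustrative):

import torch
import timm

# drop_path_rate is forwarded to XceptionAligned, giving a 0.0 -> 0.1 per-block schedule
model = timm.create_model('xception41', pretrained=False, drop_path_rate=0.1)
model.train()
out = model(torch.randn(2, 3, 299, 299))
print(out.shape)  # expected: torch.Size([2, 1000])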