Mirror of https://github.com/huggingface/pytorch-image-models.git
Add DropPath and type hints to Xception.
parent 7f19a4cce7
commit 53a4888328
@@ -6,12 +6,13 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import List, Dict, Type, Optional
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
 from timm.layers.helpers import to_3tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
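For context on the newly imported layer (not part of the diff itself): DropPath in timm.layers implements stochastic depth, randomly zeroing the residual branch for a fraction of samples during training and rescaling the survivors so the expected output is unchanged. A minimal sketch of the idea, using only torch; the helper name toy_drop_path is illustrative and not timm's API:

import torch


def toy_drop_path(x: torch.Tensor, drop_prob: float, training: bool) -> torch.Tensor:
    # Illustrative stand-in for timm.layers.DropPath: during training, each sample's
    # residual branch is dropped with probability drop_prob; kept samples are rescaled
    # by 1 / keep_prob so the expected value of the output is unchanged.
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1. - drop_prob
    # one Bernoulli draw per sample, broadcast over channel/spatial dims
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    return x * mask / keep_prob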
@@ -23,14 +24,14 @@ __all__ = ['XceptionAligned']
 class SeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
     ):
         super(SeparableConv2d, self).__init__()
         self.kernel_size = kernel_size
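As a reminder of what the annotated class wraps: a separable convolution factors a full convolution into a depthwise conv (groups equal to in_chs) followed by a 1x1 pointwise conv. A simplified sketch in plain PyTorch, ignoring timm's configurable padding, norm and activation handling; ToySeparableConv2d is a made-up name for illustration:

import torch.nn as nn


class ToySeparableConv2d(nn.Module):
    # Simplified illustration of the depthwise-separable pattern; the real
    # SeparableConv2d in timm adds configurable padding, norm and activation layers.
    def __init__(self, in_chs: int, out_chs: int, kernel_size: int = 3, stride: int = 1, dilation: int = 1):
        super().__init__()
        padding = dilation * (kernel_size - 1) // 2  # 'same'-style padding for odd kernels
        self.conv_dw = nn.Conv2d(
            in_chs, in_chs, kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=in_chs, bias=False)  # depthwise: one filter per channel
        self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)  # pointwise: mixes channels

    def forward(self, x):
        return self.conv_pw(self.conv_dw(x))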
@@ -61,15 +62,15 @@ class SeparableConv2d(nn.Module):
 class PreSeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            first_act=True,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            first_act: bool = True,
     ):
         super(PreSeparableConv2d, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -95,15 +96,16 @@ class PreSeparableConv2d(nn.Module):
 class XceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            start_with_relu=True,
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            start_with_relu: bool = True,
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(XceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -126,12 +128,16 @@ class XceptionModule(nn.Module):
                 act_layer=separable_act_layer, norm_layer=norm_layer))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         skip = x
         x = self.stack(x)
         if self.shortcut is not None:
             skip = self.shortcut(skip)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + skip
         return x
 
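The forward change above is the usual stochastic-depth pattern: the residual branch may be dropped per sample before it is added to the skip path, while the identity path always survives. A standalone sketch of that pattern, assuming a timm version where DropPath is exported from timm.layers; the Conv2d is just a stand-in for the separable conv stack:

import torch
import torch.nn as nn
from timm.layers import DropPath

# Stand-in pieces for the pattern used in XceptionModule.forward
stack = nn.Conv2d(8, 8, 3, padding=1)   # residual branch (here a single conv for brevity)
drop_path = DropPath(0.2)               # drops the branch for ~20% of samples while training

x = torch.randn(4, 8, 16, 16)
skip = x
out = drop_path(stack(x)) + skip        # identity path is never dropped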
@@ -139,14 +145,15 @@ class XceptionModule(nn.Module):
 class PreXceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(PreXceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -174,11 +181,15 @@ class PreXceptionModule(nn.Module):
             ))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         x = self.norm(x)
         skip = x
         x = self.stack(x)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + self.shortcut(skip)
         return x
 
@@ -189,15 +200,16 @@ class XceptionAligned(nn.Module):
 
     def __init__(
             self,
-            block_cfg,
-            num_classes=1000,
-            in_chans=3,
-            output_stride=32,
-            preact=False,
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            drop_rate=0.,
-            global_pool='avg',
+            block_cfg: List[Dict],
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            output_stride: int = 32,
+            preact: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
         super(XceptionAligned, self).__init__()
         assert output_stride in (8, 16, 32)
@@ -217,7 +229,11 @@ class XceptionAligned(nn.Module):
         self.feature_info = []
         self.blocks = nn.Sequential()
         module_fn = PreXceptionModule if preact else XceptionModule
+        net_num_blocks = len(block_cfg)
+        net_block_idx = 0
         for i, b in enumerate(block_cfg):
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
             b['dilation'] = curr_dilation
             if b['stride'] > 1:
                 name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
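The linear decay rule added here gives block i a drop probability that grows from 0 for the first block to drop_path_rate for the last one. A small illustration of the arithmetic, with a hypothetical depth and rate (neither value comes from the diff):

drop_path_rate = 0.1   # example value, not a default from this commit
net_num_blocks = 12    # illustrative block count only

rates = [drop_path_rate * idx / (net_num_blocks - 1) for idx in range(net_num_blocks)]
# rates[0] == 0.0 (first block is never dropped), rates[-1] == 0.1 (deepest block is dropped most often)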
@@ -230,6 +246,7 @@ class XceptionAligned(nn.Module):
                 curr_stride = next_stride
             self.blocks.add_module(str(i), module_fn(**b, **layer_args))
             self.num_features = self.blocks[-1].out_channels
+            net_block_idx += 1
 
         self.feature_info += [dict(
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
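With this change the stochastic depth rate can be set when building a model. A usage sketch, assuming a timm version that includes this commit and that extra create_model kwargs are forwarded to XceptionAligned as usual; 'xception41' is one of the aligned Xception variants registered in timm:

import torch
import timm

# drop_path_rate is forwarded to XceptionAligned.__init__ via create_model kwargs
model = timm.create_model('xception41', pretrained=False, drop_path_rate=0.1)
model.train()  # DropPath only has an effect in training mode
out = model(torch.randn(1, 3, 299, 299))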