Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add DropPath and type hints to Xception.

commit 53a4888328
parent 7f19a4cce7
@@ -6,12 +6,13 @@ https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zo
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from functools import partial
+from typing import List, Dict, Type, Optional
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
 from timm.layers.helpers import to_3tuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -23,14 +24,14 @@ __all__ = ['XceptionAligned']
 class SeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
     ):
         super(SeparableConv2d, self).__init__()
         self.kernel_size = kernel_size
@@ -61,15 +62,15 @@ class SeparableConv2d(nn.Module):
 class PreSeparableConv2d(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            kernel_size=3,
-            stride=1,
-            dilation=1,
-            padding='',
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            first_act=True,
+            in_chs: int,
+            out_chs: int,
+            kernel_size: int = 3,
+            stride: int = 1,
+            dilation: int = 1,
+            padding: PadType = '',
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            first_act: bool = True,
     ):
         super(PreSeparableConv2d, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@@ -95,15 +96,16 @@ class PreSeparableConv2d(nn.Module):
 class XceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            start_with_relu=True,
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            start_with_relu: bool = True,
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(XceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -126,12 +128,16 @@ class XceptionModule(nn.Module):
                 act_layer=separable_act_layer, norm_layer=norm_layer))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         skip = x
         x = self.stack(x)
         if self.shortcut is not None:
             skip = self.shortcut(skip)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + skip
         return x
 
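The gating added in the forward pass above is stochastic depth: during training the residual branch is randomly zeroed per sample and rescaled, while at inference it is a no-op, so `x = x + skip` adds either the full branch output or only the skip connection. Below is a minimal sketch of that idea; it is not the timm DropPath implementation, which may offer additional options and differ in detail.

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Stochastic depth sketch: randomly drop the residual branch per sample."""
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob  # assumed 0 <= drop_prob < 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x  # inference / zero rate: identity
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims (C, H, W).
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        return x * mask / keep_prob  # rescale survivors so the expectation is unchanged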
@@ -139,14 +145,15 @@ class XceptionModule(nn.Module):
 class PreXceptionModule(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
-            stride=1,
-            dilation=1,
-            pad_type='',
-            no_skip=False,
-            act_layer=nn.ReLU,
-            norm_layer=None,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            pad_type: PadType = '',
+            no_skip: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Optional[Type[nn.Module]] = None,
+            drop_path: Optional[nn.Module] = None
     ):
         super(PreXceptionModule, self).__init__()
         out_chs = to_3tuple(out_chs)
@@ -174,11 +181,15 @@ class PreXceptionModule(nn.Module):
             ))
             in_chs = out_chs[i]
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         x = self.norm(x)
         skip = x
         x = self.stack(x)
         if not self.no_skip:
+            if self.drop_path is not None:
+                x = self.drop_path(x)
             x = x + self.shortcut(skip)
         return x
 
@@ -189,15 +200,16 @@ class XceptionAligned(nn.Module):
 
     def __init__(
            self,
-            block_cfg,
-            num_classes=1000,
-            in_chans=3,
-            output_stride=32,
-            preact=False,
-            act_layer=nn.ReLU,
-            norm_layer=nn.BatchNorm2d,
-            drop_rate=0.,
-            global_pool='avg',
+            block_cfg: List[Dict],
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            output_stride: int = 32,
+            preact: bool = False,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
         super(XceptionAligned, self).__init__()
         assert output_stride in (8, 16, 32)
@@ -217,7 +229,11 @@ class XceptionAligned(nn.Module):
         self.feature_info = []
         self.blocks = nn.Sequential()
         module_fn = PreXceptionModule if preact else XceptionModule
+        net_num_blocks = len(block_cfg)
+        net_block_idx = 0
         for i, b in enumerate(block_cfg):
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
             b['dilation'] = curr_dilation
             if b['stride'] > 1:
                 name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
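The linear decay rule above assigns each block a drop probability that grows with depth, from 0 for the first block up to drop_path_rate for the last; because of the `block_dpr > 0.` guard, the first block ends up with `drop_path=None`. A small illustration with made-up numbers (the rate and block count are examples, not values from the commit):

drop_path_rate = 0.3   # illustrative value only
net_num_blocks = 4     # e.g. len(block_cfg) for a hypothetical config
block_dprs = [drop_path_rate * idx / (net_num_blocks - 1) for idx in range(net_num_blocks)]
print(block_dprs)      # approximately [0.0, 0.1, 0.2, 0.3]; the shallowest block gets no DropPath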
@@ -230,6 +246,7 @@ class XceptionAligned(nn.Module):
                 curr_stride = next_stride
             self.blocks.add_module(str(i), module_fn(**b, **layer_args))
             self.num_features = self.blocks[-1].out_channels
+            net_block_idx += 1
 
         self.feature_info += [dict(
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
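With drop_path_rate now a constructor argument of XceptionAligned, the aligned Xception variants built on it (e.g. xception41) should accept the rate at creation time, assuming create_model forwards extra kwargs to the model constructor as usual in timm. A hedged usage sketch; the model name is a real timm entrypoint, but the rate and input size are arbitrary examples:

import timm
import torch

# Assumes a timm build that includes this commit; 0.1 is an arbitrary example rate.
model = timm.create_model('xception41', pretrained=False, drop_path_rate=0.1)
model.train()  # DropPath is only active in training mode
out = model(torch.randn(2, 3, 299, 299))
print(out.shape)  # expected: torch.Size([2, 1000]) with the default 1000 classes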