Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit 5fcddb96a8
Merge branch 'master' into cait
@@ -48,7 +48,7 @@ parser = argparse.ArgumentParser(description='PyTorch Benchmark')
 parser.add_argument('--model-list', metavar='NAME', default='',
                     help='txt file based list of model names to benchmark')
 parser.add_argument('--bench', default='both', type=str,
-                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
+                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
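Note: the help text now matches the actual default of 'both'. A hypothetical invocation of the script with these flags (the list file and output name are illustrative):

    python benchmark.py --model-list models.txt --bench both --detail --results-file benchmark.csv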
@@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_mode(False)

 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
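These filters are shell-style wildcards; a minimal sketch of how a model name can be checked against them (assuming fnmatch-based matching, which is the common approach for such patterns; the helper name is illustrative):

    import fnmatch

    NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']

    def is_non_std(model_name):
        # True if the name matches any non-standard-model pattern
        return any(fnmatch.fnmatch(model_name, f) for f in NON_STD_FILTERS)

    assert is_non_std('mixer_b16_224')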
@@ -15,6 +15,7 @@ from .hrnet import *
 from .inception_resnet_v2 import *
 from .inception_v3 import *
 from .inception_v4 import *
+from .mlp_mixer import *
 from .mobilenetv3 import *
 from .nasnet import *
 from .nfnet import *
@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()

     def forward(self, x):
         shortcut = self.shortcut(x)
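The hasattr guard means any attention module can opt into this init path simply by defining reset_parameters. A minimal sketch of a layer that would be picked up (the module name and init values are illustrative, not timm code):

    import torch.nn as nn

    class MySelfAttn(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=False)

        def reset_parameters(self):
            # invoked by the enclosing block's init_weights() via the hasattr check
            nn.init.trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)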
@@ -62,9 +62,9 @@ class DlaBasic(nn.Module):
         self.bn2 = nn.BatchNorm2d(planes)
         self.stride = stride

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -73,7 +73,7 @@ class DlaBasic(nn.Module):
         out = self.conv2(out)
         out = self.bn2(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -99,9 +99,9 @@ class DlaBottleneck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -114,7 +114,7 @@ class DlaBottleneck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -154,9 +154,9 @@ class DlaBottle2neck(nn.Module):
         self.bn3 = nn.BatchNorm2d(outplanes)
         self.relu = nn.ReLU(inplace=True)

-    def forward(self, x, residual=None):
-        if residual is None:
-            residual = x
+    def forward(self, x, shortcut=None):
+        if shortcut is None:
+            shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -177,26 +177,26 @@ class DlaBottle2neck(nn.Module):
         out = self.conv3(out)
         out = self.bn3(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out


 class DlaRoot(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, residual):
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
         super(DlaRoot, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
         self.bn = nn.BatchNorm2d(out_channels)
         self.relu = nn.ReLU(inplace=True)
-        self.residual = residual
+        self.shortcut = shortcut

     def forward(self, *x):
         children = x
         x = self.conv(torch.cat(x, 1))
         x = self.bn(x)
-        if self.residual:
+        if self.shortcut:
             x += children[0]
         x = self.relu(x)

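The renamed flag keeps DlaRoot's behavior: concatenate the children, aggregate with a 1x1 conv, and optionally add the first child back. A quick shape sketch against the class as defined above (channel and spatial sizes illustrative):

    import torch
    root = DlaRoot(in_channels=128, out_channels=64, kernel_size=1, shortcut=True)
    x2 = torch.randn(1, 64, 56, 56)
    x1 = torch.randn(1, 64, 56, 56)
    out = root(x2, x1)  # cat -> 1x1 conv -> BN -> += x2 -> ReLU
    assert out.shape == (1, 64, 56, 56)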
@@ -206,7 +206,7 @@ class DlaRoot(nn.Module):
 class DlaTree(nn.Module):
     def __init__(self, levels, block, in_channels, out_channels, stride=1,
                  dilation=1, cardinality=1, base_width=64,
-                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False):
+                 level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
         super(DlaTree, self).__init__()
         if root_dim == 0:
             root_dim = 2 * out_channels
@@ -226,24 +226,24 @@ class DlaTree(nn.Module):
                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                 nn.BatchNorm2d(out_channels))
         else:
-            cargs.update(dict(root_kernel_size=root_kernel_size, root_residual=root_residual))
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
             self.tree1 = DlaTree(
                 levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
             self.tree2 = DlaTree(
                 levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
         if levels == 1:
-            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_residual)
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
         self.level_root = level_root
         self.root_dim = root_dim
         self.levels = levels

-    def forward(self, x, residual=None, children=None):
+    def forward(self, x, shortcut=None, children=None):
         children = [] if children is None else children
         bottom = self.downsample(x)
-        residual = self.project(bottom)
+        shortcut = self.project(bottom)
         if self.level_root:
             children.append(bottom)
-        x1 = self.tree1(x, residual)
+        x1 = self.tree1(x, shortcut)
         if self.levels == 1:
             x2 = self.tree2(x1)
             x = self.root(x2, x1, *children)
@@ -255,7 +255,7 @@ class DlaTree(nn.Module):

 class DLA(nn.Module):
     def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
-                 cardinality=1, base_width=64, block=DlaBottle2neck, residual_root=False,
+                 cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
                  drop_rate=0.0, global_pool='avg'):
         super(DLA, self).__init__()
         self.channels = channels
@@ -271,7 +271,7 @@ class DLA(nn.Module):
             nn.ReLU(inplace=True))
         self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
         self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
-        cargs = dict(cardinality=cardinality, base_width=base_width, root_residual=residual_root)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
         self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
         self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
         self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
@@ -413,7 +413,7 @@ def dla60x(pretrained=False, **kwargs):  # DLA-X-60
 def dla102(pretrained=False, **kwargs):  # DLA-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla102', pretrained, **model_kwargs)


@@ -421,7 +421,7 @@ def dla102(pretrained=False, **kwargs):  # DLA-102
 def dla102x(pretrained=False, **kwargs):  # DLA-X-102
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=32, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x', pretrained, **model_kwargs)


@@ -429,7 +429,7 @@ def dla102x(pretrained=False, **kwargs):  # DLA-X-102
 def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
     model_kwargs = dict(
         levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, cardinality=64, base_width=4, residual_root=True, **kwargs)
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
     return _create_dla('dla102x2', pretrained, **model_kwargs)


@@ -437,5 +437,5 @@ def dla102x2(pretrained=False, **kwargs):  # DLA-X-102 64
 def dla169(pretrained=False, **kwargs):  # DLA-169
     model_kwargs = dict(
         levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
-        block=DlaBottleneck, residual_root=True, **kwargs)
+        block=DlaBottleneck, shortcut_root=True, **kwargs)
     return _create_dla('dla169', pretrained, **model_kwargs)
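Since only local variable and keyword names change (residual_root -> shortcut_root), no parameter or buffer names move, so existing DLA checkpoints should still load; only callers passing the old keyword need updating. Using the registry entry point sidesteps the keyword entirely (a sketch, assuming the standard timm API):

    import timm
    model = timm.create_model('dla102', pretrained=False)  # shortcut_root=True is set inside the entry point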
@@ -184,7 +184,7 @@ class DepthwiseSeparableConv(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv_dw(x)
         x = self.bn1(x)
@@ -200,7 +200,7 @@ class DepthwiseSeparableConv(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x


@@ -258,7 +258,7 @@ class InvertedResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Point-wise expansion
         x = self.conv_pw(x)
@@ -281,7 +281,7 @@ class InvertedResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut

         return x

@@ -308,7 +308,7 @@ class CondConvResidual(InvertedResidual):
         self.routing_fn = nn.Linear(in_chs, self.num_experts)

     def forward(self, x):
-        residual = x
+        shortcut = x

         # CondConv routing
         pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
@@ -335,7 +335,7 @@ class CondConvResidual(InvertedResidual):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut
         return x


@@ -390,7 +390,7 @@ class EdgeResidual(nn.Module):
         return info

     def forward(self, x):
-        residual = x
+        shortcut = x

         # Expansion convolution
         x = self.conv_exp(x)
@@ -408,6 +408,6 @@ class EdgeResidual(nn.Module):
         if self.has_residual:
             if self.drop_path_rate > 0.:
                 x = drop_path(x, self.drop_path_rate, self.training)
-            x += residual
+            x += shortcut

         return x
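All four block types share the same tail: optional stochastic depth on the main branch, then the shortcut add. For reference, a minimal sketch of a drop_path function with the signature used here (a common stochastic-depth formulation; not necessarily timm's exact implementation):

    import torch

    def drop_path(x, drop_prob: float = 0., training: bool = False):
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        # one keep/drop decision per sample, broadcast across C, H, W
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob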
@@ -112,7 +112,7 @@ class GhostBottleneck(nn.Module):


     def forward(self, x):
-        residual = x
+        shortcut = x

         # 1st ghost bottleneck
         x = self.ghost1(x)
@@ -129,7 +129,7 @@ class GhostBottleneck(nn.Module):
         # 2nd ghost bottleneck
         x = self.ghost2(x)

-        x += self.shortcut(residual)
+        x += self.shortcut(shortcut)
         return x


@@ -1,7 +1,6 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
anti_aliasing.py (file deleted)
@@ -1,60 +0,0 @@
-import torch
-import torch.nn.parallel
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AntiAliasDownsampleLayer(nn.Module):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
-        super(AntiAliasDownsampleLayer, self).__init__()
-        if no_jit:
-            self.op = Downsample(channels, filt_size, stride)
-        else:
-            self.op = DownsampleJIT(channels, filt_size, stride)
-
-    # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
-
-    def forward(self, x):
-        return self.op(x)
-
-
-@torch.jit.script
-class DownsampleJIT(object):
-    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
-        self.channels = channels
-        self.stride = stride
-        self.filt_size = filt_size
-        assert self.filt_size == 3
-        assert stride == 2
-        self.filt = {}  # lazy init by device for DataParallel compat
-
-    def _create_filter(self, like: torch.Tensor):
-        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-
-    def __call__(self, input: torch.Tensor):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        filt = self.filt.get(str(input.device), self._create_filter(input))
-        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
-
-
-class Downsample(nn.Module):
-    def __init__(self, channels=None, filt_size=3, stride=2):
-        super(Downsample, self).__init__()
-        self.channels = channels
-        self.filt_size = filt_size
-        self.stride = stride
-
-        assert self.filt_size == 3
-        filt = torch.tensor([1., 2., 1.])
-        filt = filt[:, None] * filt[None, :]
-        filt = filt / torch.sum(filt)
-
-        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
-        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
-
-    def forward(self, input):
-        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
-        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
@@ -3,8 +3,6 @@ BlurPool layer inspired by
 - Kornia's Max_BlurPool2d
 - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`

-FIXME merge this impl with those in `anti_aliasing.py`
-
 Hacked together by Chris Ha and Ross Wightman
 """

@@ -12,7 +10,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from typing import Dict
 from .padding import get_padding


@@ -29,30 +26,17 @@ class BlurPool2d(nn.Module):
     Returns:
         torch.Tensor: the transformed tensor.
     """
-    filt: Dict[str, torch.Tensor]
-
     def __init__(self, channels, filt_size=3, stride=2) -> None:
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
         self.filt_size = filt_size
         self.stride = stride
-        pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
-        self.padding = nn.ReflectionPad2d(pad_size)
-        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs)  # for torchscript compat
-        self.filt = {}  # lazy init by device for DataParallel compat
+        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+        self.register_buffer('filt', blur_filter, persistent=False)

-    def _create_filter(self, like: torch.Tensor):
-        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
-        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
-
-    def _apply(self, fn):
-        # override nn.Module _apply, reset filter cache if used
-        self.filt = {}
-        super(BlurPool2d, self)._apply(fn)
-
-    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        C = input_tensor.shape[1]
-        blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
-        return F.conv2d(
-            self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.pad(x, self.padding, 'reflect')
+        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
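With the filter registered as a non-persistent buffer, it follows the module through .to(), .half() and DataParallel replication, replacing the old per-device cache and the _apply override. A quick shape check of the layer as rewritten above (sizes illustrative):

    import torch
    from timm.models.layers import BlurPool2d

    pool = BlurPool2d(channels=64, filt_size=3, stride=2)
    y = pool(torch.randn(2, 64, 56, 56))
    assert y.shape == (2, 64, 28, 28)  # reflect-pad by 1, then 3x3 depthwise blur at stride 2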
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import to_2tuple
+from .weight_init import trunc_normal_


 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width
@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_
+

 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0
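The new reset_parameters methods use truncated-normal init with fan-in scaling, std = fan_in ** -0.5, where fan_in is the weight's input dimension (shape[1]). For example, a 1x1 qkv projection from 512 channels gets std = 512 ** -0.5 ≈ 0.0442 (a standalone sketch using PyTorch's built-in trunc_normal_):

    import torch.nn as nn

    qkv = nn.Conv2d(512, 1536, 1, bias=False)  # fan_in = qkv.weight.shape[1] = 512
    nn.init.trunc_normal_(qkv.weight, std=qkv.weight.shape[1] ** -0.5)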
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_


 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)

         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
timm/models/mlp_mixer.py (new file, 292 lines)
@@ -0,0 +1,292 @@
""" MLP-Mixer in PyTorch

Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py

Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601

@article{tolstikhin2021,
  title={MLP-Mixer: An all-MLP Architecture for Vision},
  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
      Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  journal={arXiv preprint arXiv:2105.01601},
  year={2021}
}

A thank you to paper authors for releasing code and weights.

Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, lecun_normal_
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        'first_conv': 'stem.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = dict(
    mixer_s32_224=_cfg(),
    mixer_s16_224=_cfg(),
    mixer_b32_224=_cfg(),
    mixer_b16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
    ),
    mixer_b16_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
        num_classes=21843
    ),
    mixer_l32_224=_cfg(),
    mixer_l16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
    ),
    mixer_l16_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
        num_classes=21843
    ),
)


class Mlp(nn.Module):
    """ MLP Block
    NOTE: same impl as ViT, move to common location
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    NOTE: same impl as ViT, move to common location
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.patch_grid[0] * self.patch_grid[1]

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class MixerBlock(nn.Module):

    def __init__(
            self, dim, seq_len, tokens_dim, channels_dim,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x


class MlpMixer(nn.Module):

    def __init__(
            self,
            num_classes=1000,
            img_size=224,
            in_chans=3,
            patch_size=16,
            num_blocks=8,
            hidden_dim=512,
            tokens_dim=256,
            channels_dim=2048,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop=0.,
            drop_path=0.,
            nlhb=False,
    ):
        super().__init__()
        self.num_classes = num_classes

        self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim)
        # FIXME drop_path (stochastic depth scaling rule?)
        self.blocks = nn.Sequential(*[
            MixerBlock(
                hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
                norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
            for _ in range(num_blocks)])
        self.norm = norm_layer(hidden_dim)
        self.head = nn.Linear(hidden_dim, self.num_classes)  # zero init

        self.init_weights(nlhb=nlhb)

    def init_weights(self, nlhb=False):
        head_bias = -math.log(self.num_classes) if nlhb else 0.
        for n, m in self.named_modules():
            _init_weights(m, n, head_bias=head_bias)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)
        return x


def _init_weights(m, n: str, head_bias: float = 0.):
    """ Mixer weight initialization (trying to match Flax defaults)
    """
    if isinstance(m, nn.Linear):
        if n.startswith('head'):
            nn.init.zeros_(m.weight)
            nn.init.constant_(m.bias, head_bias)
        else:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                if 'mlp' in n:
                    nn.init.normal_(m.bias, std=1e-6)
                else:
                    nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
    if default_cfg is None:
        default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]
    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for MLP-Mixer models.')

    model = build_model_with_cfg(
        MlpMixer, variant, pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        **kwargs)

    return model


@register_model
def mixer_s32_224(pretrained=False, **kwargs):
    """ Mixer-S/32 224x224
    """
    model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
    model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_s16_224(pretrained=False, **kwargs):
    """ Mixer-S/16 224x224
    """
    model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
    model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_b32_224(pretrained=False, **kwargs):
    """ Mixer-B/32 224x224
    """
    model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
    model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_b16_224(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
    """
    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
    model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_b16_224_in21k(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
    """
    model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
    model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_l32_224(pretrained=False, **kwargs):
    """ Mixer-L/32 224x224.
    """
    model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
    model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_l16_224(pretrained=False, **kwargs):
    """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
    """
    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
    model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
    return model


@register_model
def mixer_l16_224_in21k(pretrained=False, **kwargs):
    """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
    """
    model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
    model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
    return model
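A quick smoke test of the new model, scaled down for speed; everything it needs is defined in the file above (random weights, CPU):

    import torch
    model = MlpMixer(num_classes=10, num_blocks=2, hidden_dim=64, tokens_dim=32, channels_dim=256)
    out = model(torch.randn(2, 3, 224, 224))  # 14x14 = 196 patches of 16x16
    assert out.shape == (2, 10)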
@@ -91,7 +91,7 @@ class Bottle2neck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -124,9 +124,9 @@ class Bottle2neck(nn.Module):
             out = self.se(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.relu(out)

         return out
@@ -105,7 +105,7 @@ class ResNestBottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -132,9 +132,9 @@ class ResNestBottleneck(nn.Module):
             out = self.drop_block(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out += residual
+        out += shortcut
         out = self.act3(out)
         return out

@@ -241,31 +241,31 @@ default_cfgs = {

     # ResNet-RS models
     'resnetrs50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
-        input_size=(3, 160, 160), pool_size=(4, 4), crop_pct=0.91, test_input_size=(3, 224, 224),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
+        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs101': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
         input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs152': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152-b1efe56d.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs200': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200-b455b791.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs270': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270-cafcfbc7.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs350': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350-06d9bfac.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
         input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
         interpolation='bicubic', first_conv='conv1.0'),
     'resnetrs420': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420-d26764a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
         input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
         interpolation='bicubic', first_conv='conv1.0'),
 }
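These configs encode the ResNet-RS train/test resolution split: input_size is the training resolution, test_input_size the larger eval resolution, and pool_size should match the training feature map at stride 32 (hence the 4 -> 5 fix above: 160 / 32 = 5 for resnetrs50). A quick eval-side check at the test resolution (a sketch, assuming the standard timm API):

    import timm, torch
    model = timm.create_model('resnetrs50', pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))  # test_input_size resolution
    assert out.shape == (1, 1000)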
@@ -315,7 +315,7 @@ class BasicBlock(nn.Module):
             nn.init.zeros_(self.bn2.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -337,8 +337,8 @@ class BasicBlock(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act2(x)

         return x
@@ -385,7 +385,7 @@ class Bottleneck(nn.Module):
             nn.init.zeros_(self.bn3.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x

         x = self.conv1(x)
         x = self.bn1(x)
@@ -413,8 +413,8 @@ class Bottleneck(nn.Module):
             x = self.drop_path(x)

         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act3(x)

         return x
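The same rename is applied mechanically across all of these blocks; the underlying pattern is unchanged. A generic standalone sketch of it (not timm code):

    import torch.nn as nn
    import torch.nn.functional as F

    class ResidualAdd(nn.Module):
        def __init__(self, body, downsample=None):
            super().__init__()
            self.body = body              # main conv branch
            self.downsample = downsample  # e.g. strided 1x1 conv + BN to match shapes

        def forward(self, x):
            shortcut = x
            x = self.body(x)
            if self.downsample is not None:
                shortcut = self.downsample(shortcut)
            return F.relu(x + shortcut)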
@@ -92,7 +92,7 @@ class Bottleneck(nn.Module):
     """

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -106,9 +106,9 @@ class Bottleneck(nn.Module):
         out = self.bn3(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
@@ -204,7 +204,7 @@ class SEResNetBlock(nn.Module):
         self.stride = stride

     def forward(self, x):
-        residual = x
+        shortcut = x

         out = self.conv1(x)
         out = self.bn1(out)
@@ -215,9 +215,9 @@ class SEResNetBlock(nn.Module):
         out = self.relu(out)

         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)

-        out = self.se_module(out) + residual
+        out = self.se_module(out) + shortcut
         out = self.relu(out)

         return out
@@ -76,7 +76,7 @@ class SelectiveKernelBasic(nn.Module):
             nn.init.zeros_(self.conv2.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         if self.se is not None:
@@ -84,8 +84,8 @@ class SelectiveKernelBasic(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -124,7 +124,7 @@ class SelectiveKernelBottleneck(nn.Module):
             nn.init.zeros_(self.conv3.bn.weight)

     def forward(self, x):
-        residual = x
+        shortcut = x
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
@@ -133,8 +133,8 @@ class SelectiveKernelBottleneck(nn.Module):
         if self.drop_path is not None:
             x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(residual)
-        x += residual
+            shortcut = self.downsample(shortcut)
+        x += shortcut
         x = self.act(x)
         return x

@@ -5,16 +5,13 @@ https://arxiv.org/pdf/2003.13630.pdf
 Original model: https://github.com/mrT23/TResNet

 """
-import copy
 from collections import OrderedDict
-from functools import partial

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from .helpers import build_model_with_cfg
-from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, InplaceAbn, ClassifierHead, SEModule
+from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
 from .registry import register_model

 __all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
@@ -92,9 +89,9 @@ class BasicBlock(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -102,7 +99,7 @@ class BasicBlock(nn.Module):
         if self.se is not None:
             out = self.se(out)

-        out += residual
+        out += shortcut
         out = self.relu(out)
         return out

@@ -139,9 +136,9 @@ class Bottleneck(nn.Module):

     def forward(self, x):
         if self.downsample is not None:
-            residual = self.downsample(x)
+            shortcut = self.downsample(x)
         else:
-            residual = x
+            shortcut = x

         out = self.conv1(x)
         out = self.conv2(out)
@@ -149,22 +146,19 @@ class Bottleneck(nn.Module):
             out = self.se(out)

         out = self.conv3(out)
-        out = out + residual  # no inplace
+        out = out + shortcut  # no inplace
         out = self.relu(out)

         return out


 class TResNet(nn.Module):
-    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
-                 global_pool='fast', drop_rate=0.):
+    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(TResNet, self).__init__()

-        # JIT layers
-        space_to_depth = SpaceToDepthModule()
-        aa_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
+        aa_layer = BlurPool2d

         # TResnet stages
         self.inplanes = int(64 * width_factor)
@@ -181,7 +175,7 @@ class TResNet(nn.Module):

         # body
         self.body = nn.Sequential(OrderedDict([
-            ('SpaceToDepth', space_to_depth),
+            ('SpaceToDepth', SpaceToDepthModule()),
             ('conv1', conv1),
             ('layer1', layer1),
             ('layer2', layer2),
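After this change TResNet always anti-aliases with BlurPool2d, and the no_aa_jit constructor argument is gone, so any caller passing it must drop it. E.g. (a sketch; the layers=[3, 4, 11, 3] depths follow the tresnet_m entry point, used here for illustration):

    from timm.models.tresnet import TResNet
    model = TResNet(layers=[3, 4, 11, 3], num_classes=1000)  # passing no_aa_jit would now raise TypeError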