features/edgevit3 (#214)

* add backbone model 'EdgeVit'
pull/236/head
Jiabei-prog 2022-11-17 14:30:12 +08:00 committed by GitHub
parent 9809c3b184
commit 17c1f39b6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 700 additions and 1 deletions

View File

@ -0,0 +1,111 @@
_base_ = '../../../base.py'
log_config = dict(
interval=10,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
# model settings
model = dict(
type='Classification',
backbone=dict(
type='EdgeVit',
depth=[1, 1, 3, 2],
embed_dim=[36, 72, 144, 288],
head_dim=36,
mlp_ratio=[4] * 4,
qkv_bias=True,
num_classes=1000,
drop_path_rate=0.1,
sr_ratios=[4, 2, 2, 1]),
head=dict(
type='ClsHead',
with_avg_pool=True,
in_channels=288,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth', label_smooth=0),
))
data_train_list = 'data/imagenet_raw/meta/train_labeled.txt'
data_train_root = 'data/imagenet_raw/train/'
data_test_list = 'data/imagenet_raw/meta/val_labeled.txt'
data_test_root = 'data/imagenet_raw/validation/'
dataset_type = 'ClsDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(
type='MAEFtAugment',
input_size=224,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
is_train=True),
]
test_pipeline = [
dict(
type='MAEFtAugment',
input_size=224,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
is_train=False,
),
]
data = dict(
imgs_per_gpu=512,
workers_per_gpu=10,
use_repeated_augment_sampler=True,
train=dict(
type=dataset_type,
data_source=dict(
list_file=data_train_list,
root=data_train_root,
type='ClsSourceImageList'),
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_source=dict(
list_file=data_test_list,
root=data_test_root,
type='ClsSourceImageList'),
pipeline=test_pipeline))
eval_config = dict(initial=True, interval=1, gpu_collect=True)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
dist_eval=True,
evaluators=[dict(type='ClsEvaluator', topk=(1, 5))],
)
]
# additional hooks
custom_hooks = []
# optimizer
optimizer = dict(type='AdamW', lr=2e-3, weight_decay=0.05)
# learning policy
lr_config = dict(
policy='CosineAnnealingWarmupByEpoch',
min_lr=1e-5,
warmup='linear',
warmup_iters=5,
warmup_ratio=1e-6,
# warmup_lr=1e-6,
warmup_by_epoch=True,
by_epoch=True)
checkpoint_config = dict(interval=10)
# runtime settings
total_epochs = 300
ema = dict(decay=0.99996)

View File

@ -0,0 +1,42 @@
_base_ = './EdgeVit_b512x8_300e_jpg.py'
# model settings
model = dict(
type='Classification',
train_preprocess=['mixUp'],
mixup_cfg=dict(
mixup_alpha=0.8,
cutmix_alpha=1.0,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode='batch',
label_smoothing=0.1,
num_classes=1000),
backbone=dict(
type='EdgeVit',
depth=[1, 2, 5, 3],
embed_dim=[48, 96, 240, 384],
head_dim=48,
mlp_ratio=[4] * 4,
qkv_bias=True,
num_classes=1000,
drop_path_rate=0.1,
sr_ratios=[4, 2, 2, 1]),
head=dict(
type='ClsHead',
with_avg_pool=True,
in_channels=384,
loss_config={
'type': 'SoftTargetCrossEntropy',
},
with_fc=True))
data = dict(
imgs_per_gpu=128,
workers_per_gpu=10,
use_repeated_augment_sampler=True,
)
# optimizer
update_interval = 8
optimizer_config = dict(update_interval=update_interval)

View File

@ -0,0 +1,31 @@
_base_ = './EdgeVit_b512x8_300e_jpg.py'
model = dict(
type='Classification',
backbone=dict(
type='EdgeVit',
depth=[1, 1, 3, 1],
embed_dim=[48, 96, 240, 384],
head_dim=48,
mlp_ratio=[4] * 4,
qkv_bias=True,
num_classes=1000,
drop_path_rate=0.1,
sr_ratios=[4, 2, 2, 1]),
head=dict(
type='ClsHead',
with_avg_pool=True,
in_channels=384,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth', label_smooth=0.1),
))
# input data settings
data = dict(
imgs_per_gpu=256,
workers_per_gpu=10,
use_repeated_augment_sampler=True,
)
# optimizer
update_interval = 4
optimizer_config = dict(update_interval=update_interval)

View File

@ -0,0 +1,32 @@
_base_ = './EdgeVit_b512x8_300e_jpg.py'
# model settings
model = dict(
type='Classification',
backbone=dict(
type='EdgeVit',
depth=[1, 1, 3, 2],
embed_dim=[36, 72, 144, 288],
head_dim=36,
mlp_ratio=[4] * 4,
qkv_bias=True,
num_classes=1000,
drop_path_rate=0.1,
sr_ratios=[4, 2, 2, 1]),
head=dict(
type='ClsHead',
with_avg_pool=True,
in_channels=288,
loss_config=dict(
type='CrossEntropyLossWithLabelSmooth', label_smooth=0.1),
))
# input data settings
data = dict(
imgs_per_gpu=512,
workers_per_gpu=10,
use_repeated_augment_sampler=True,
)
# optimizer
update_interval = 2
optimizer_config = dict(update_interval=update_interval)

View File

@ -82,5 +82,8 @@
| efficientformer_l1 | [efficientformer_l1](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/efficientformer/efficientformer_l1.py) | 80.102 | 94.934 | 1820 | 7.5 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/efficientformer/efficientformer_l1_1000d.pth) |
| efficientformer_l3 | [efficientformer_l3](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/efficientformer/efficientformer_l3.py) | 82.272 | 96.028 | 2436 | 13.07 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/efficientformer/efficientformer_l3_300d.pth) |
| efficientformer_l7 | [efficientformer_l7](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/efficientformer/efficientformer_l7.py) | 83.076 | 96.44 | 1622 | 18.96 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/efficientformer/efficientformer_l7_300d.pth) |
| EdgeVit_xxs_b512_224 | [EdgeVit_xxs_b512_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/edgevit/imagenet_edgeVIT_xxs_jpg.py) | 75.18 | 92.188 | 206 | 8.67 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/edgevit/edgexxs/ClsEvaluator_neck_top1_best.pth) |
| EdgeVit_xs_b256_224 | [EdgeVit_xs_b256_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/edgevit/imagenet_edgeVIT_xs_jpg.py) | 77.624 | 93.47 | 551 | 8.04 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/edgevit/edgexs/ClsEvaluator_neck_top1_best.pth) |
| EdgeVit_s_b128_224 | [EdgeVit_s_b128_224](https://github.com/alibaba/EasyCV/tree/master/configs/classification/imagenet/edgevit/imagenet_edgeVIT_s_jpg.py) | 80.3 | 95.302 | 576 | 13.49 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/classification/edgevit/edges/ClsEvaluator_neck_top1_best.pth) |
(ps: 通过导入官方模型得到推理结果需要torch.__version__ >= 1.9.0推理的输入尺寸默认为224机器默认为V100 16G其中gpu memory记录的是gpu peak memory)

View File

@ -3,6 +3,7 @@ from .benchmark_mlp import BenchMarkMLP
from .bninception import BNInception
from .conv_mae_vit import FastConvMAEViT
from .conv_vitdet import ConvViTDet
from .edgevit import EdgeVit
from .efficientformer import EfficientFormer
from .face_keypoint_backbone import FaceKeypointBackbone
from .genet import PlainNet

View File

@ -0,0 +1,418 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This model is taken from
https://github.com/SamsungLabs/EdgeViTs
"""
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from easycv.models.utils import ConvMlp, Mlp
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger
from ..registry import BACKBONES
class GlobalSparseAttn(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# self.upsample = nn.Upsample(scale_factor=sr_ratio, mode='nearest')
self.sr = sr_ratio
if self.sr > 1:
self.sampler = nn.AvgPool2d(1, sr_ratio)
kernel_size = sr_ratio
self.LocalProp = nn.ConvTranspose2d(
dim, dim, kernel_size, stride=sr_ratio, groups=dim)
self.norm = nn.LayerNorm(dim)
else:
self.sampler = nn.Identity()
self.upsample = nn.Identity()
self.norm = nn.Identity()
def forward(self, x, H: int, W: int):
B, N, C = x.shape
if self.sr > 1.:
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.sampler(x)
x = x.flatten(2).transpose(1, 2)
qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
if self.sr > 1:
x = x.permute(0, 2, 1).reshape(B, C, int(H / self.sr),
int(W / self.sr))
x = self.LocalProp(x)
x = x.reshape(B, C, -1).permute(0, 2, 1)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LocalAgg(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm):
super().__init__()
self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
self.norm1 = nn.BatchNorm2d(dim)
self.conv1 = nn.Conv2d(dim, dim, 1)
self.conv2 = nn.Conv2d(dim, dim, 1)
self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = nn.BatchNorm2d(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ConvMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.pos_embed(x)
x = x + self.drop_path(
self.conv2(self.attn(self.conv1(self.norm1(x)))))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class SelfAttn(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1.):
super().__init__()
self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
self.norm1 = norm_layer(dim)
self.attn = GlobalSparseAttn(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
# global layer_scale
# self.ls = layer_scale
def forward(self, x):
x = x + self.pos_embed(x)
B, N, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).reshape(B, N, H, W)
return x
class LGLBlock(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1.):
super().__init__()
if sr_ratio > 1:
self.LocalAgg = LocalAgg(dim, num_heads, mlp_ratio, qkv_bias,
qk_scale, drop, attn_drop, drop_path,
act_layer, norm_layer)
else:
self.LocalAgg = nn.Identity()
self.SelfAttn = SelfAttn(dim, num_heads, mlp_ratio, qkv_bias, qk_scale,
drop, attn_drop, drop_path, act_layer,
norm_layer, sr_ratio)
def forward(self, x):
x = self.LocalAgg(x)
x = self.SelfAttn(x)
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (
img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.norm = nn.LayerNorm(embed_dim)
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({B}*{C}*{H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
B, C, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
return x
@BACKBONES.register_module()
class EdgeVit(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
https://arxiv.org/abs/2010.11929
"""
def __init__(self,
depth=[1, 2, 3, 2],
img_size=224,
in_chans=3,
num_classes=1000,
embed_dim=[48, 96, 240, 384],
head_dim=48,
mlp_ratio=[4] * 4,
qkv_bias=True,
qk_scale=None,
representation_size=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-8),
sr_ratios=[4, 2, 2, 1],
pretrained=None):
"""
Args:
depth (list): depth of each stage
img_size (int, tuple): input image size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (list): embedding dimension of each stage
head_dim (int): head dimension
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer (nn.Module): normalization layer
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed1 = PatchEmbed(
img_size=img_size,
patch_size=4,
in_chans=in_chans,
embed_dim=embed_dim[0])
self.patch_embed2 = PatchEmbed(
img_size=img_size // 4,
patch_size=2,
in_chans=embed_dim[0],
embed_dim=embed_dim[1])
self.patch_embed3 = PatchEmbed(
img_size=img_size // 8,
patch_size=2,
in_chans=embed_dim[1],
embed_dim=embed_dim[2])
self.patch_embed4 = PatchEmbed(
img_size=img_size // 16,
patch_size=2,
in_chans=embed_dim[2],
embed_dim=embed_dim[3])
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))
] # stochastic depth decay rule
num_heads = [dim // head_dim for dim in embed_dim]
self.blocks1 = nn.ModuleList([
LGLBlock(
dim=embed_dim[0],
num_heads=num_heads[0],
mlp_ratio=mlp_ratio[0],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
sr_ratio=sr_ratios[0]) for i in range(depth[0])
])
self.blocks2 = nn.ModuleList([
LGLBlock(
dim=embed_dim[1],
num_heads=num_heads[1],
mlp_ratio=mlp_ratio[1],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i + depth[0]],
norm_layer=norm_layer,
sr_ratio=sr_ratios[1]) for i in range(depth[1])
])
self.blocks3 = nn.ModuleList([
LGLBlock(
dim=embed_dim[2],
num_heads=num_heads[2],
mlp_ratio=mlp_ratio[2],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i + depth[0] + depth[1]],
norm_layer=norm_layer,
sr_ratio=sr_ratios[2]) for i in range(depth[2])
])
self.blocks4 = nn.ModuleList([
LGLBlock(
dim=embed_dim[3],
num_heads=num_heads[3],
mlp_ratio=mlp_ratio[3],
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i + depth[0] + depth[1] + depth[2]],
norm_layer=norm_layer,
sr_ratio=sr_ratios[3]) for i in range(depth[3])
])
self.norm = nn.BatchNorm2d(embed_dim[-1])
# Representation layer
if representation_size:
self.num_features = representation_size
self.pre_logits = nn.Sequential(
OrderedDict([('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())]))
else:
self.pre_logits = nn.Identity()
self.pretrained = pretrained
self.init_weights()
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
pretrained = pretrained or self.pretrained
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
if isinstance(pretrained, str):
self.apply(_init_weights)
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
self.apply(_init_weights)
else:
raise TypeError('pretrained must be a str or None')
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward_features(self, x):
x = self.patch_embed1(x)
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
x = self.patch_embed2(x)
for blk in self.blocks2:
x = blk(x)
x = self.patch_embed3(x)
for blk in self.blocks3:
x = blk(x)
x = self.patch_embed4(x)
for blk in self.blocks4:
x = blk(x)
x = self.norm(x)
x = self.pre_logits(x)
return x
def forward(self, x):
x = self.forward_features(x)
return [x]

View File

@ -18,7 +18,7 @@ from .scale import Scale
# from .weight_init import (bias_init_with_prob, kaiming_init, normal_init,
# uniform_init, xavier_init)
from .sobel import Sobel
from .transformer import (MLP, DropPath, Mlp, TransformerEncoder,
from .transformer import (MLP, ConvMlp, DropPath, Mlp, TransformerEncoder,
TransformerEncoderLayer, _get_activation_fn,
_get_clones)

View File

@ -66,6 +66,31 @@ class Mlp(nn.Module):
return x
class ConvMlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x

View File

@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.backbones import EdgeVit
class EdgeVitTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_vitdet(self):
model = EdgeVit(
img_size=224,
depth=[1, 1, 3, 2],
embed_dim=[36, 72, 144, 288],
head_dim=36,
mlp_ratio=[4] * 4,
qkv_bias=True,
num_classes=1000,
drop_path_rate=0.1,
sr_ratios=[4, 2, 2, 1],
)
model.init_weights()
model.train()
imgs = torch.rand(36, 3, 224, 224)
feat = model(imgs)
self.assertEqual(len(feat), 1)
self.assertEqual(feat[0].shape, torch.Size([36, 288, 7, 7]))
if __name__ == '__main__':
unittest.main()