[Enhancement] Improve the accuracy of ResNet (#572)
* add itp timm * minor update * minor update * minor update * add rep aug, minor update on configs * minor update * add target threshold * add decaymulti * minor update * minor update * add lbl smooth * update lr * reorganize config files and code * minor bugfixes * remove unused parts and minor fixes on cfg * critical bugfix, add test and cfg update * refactor code * update doc string * remove duplicate code * refactor drop path in resnet * rename * Modify configs and add README&metafile * Update metafile Co-authored-by: mzr1996 <mzr1996@163.com>pull/623/head
parent
e7c06b8541
commit
0bbbb04429
|
@ -0,0 +1,53 @@
|
|||
_base_ = ['./pipelines/rand_aug.py']
|
||||
|
||||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies={{_base_.rand_increasing_policies}},
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=7,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
interpolation='bicubic')),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=(236, -1)),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=256,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# For the standard test, replace `data/imagenet/val` with `data/imagenet/test` (and update `ann_file` to match)
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=1, metric='accuracy')
|
|
@ -0,0 +1,53 @@
|
|||
_base_ = ['./pipelines/rand_aug.py']
|
||||
|
||||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=160),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies={{_base_.rand_increasing_policies}},
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=6,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
interpolation='bicubic')),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', size=(236, -1)),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=256,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# For the standard test, replace `data/imagenet/val` with `data/imagenet/test` (and update `ann_file` to match)
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=1, metric='accuracy')
|
|
@ -0,0 +1,12 @@
|
|||
# optimizer
|
||||
optimizer = dict(type='Lamb', lr=0.005, weight_decay=0.02)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=1.0e-6,
|
||||
warmup='linear',
|
||||
# For ImageNet-1k with a total batch size of 2048, there are 626 iters per epoch; warm up for 5 epochs.
|
||||
warmup_iters=5 * 626,
|
||||
warmup_ratio=0.0001)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=100)
|
|
@ -55,3 +55,8 @@ The depth of representations is of central importance for many visual recognitio
|
|||
| ResNetV1D-101 | 44.57 | 8.09 | 78.93 | 94.48 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_b32x8_imagenet_20210531-6e13bcd3.log.json) |
|
||||
| ResNetV1D-152 | 60.21 | 11.82 | 79.41 | 94.70 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnetv1d152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_b32x8_imagenet_20210531-278cf22a.log.json) |
|
||||
| ResNet-50 (fp16) | 25.56 | 4.12 | 76.30 | 93.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32-fp16_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/fp16/resnet50_batch256_fp16_imagenet_20210320-b3964210.pth) | [log](https://download.openmmlab.com/mmclassification/v0/fp16/resnet50_batch256_fp16_imagenet_20210320-b3964210.log.json) |
|
||||
| ResNet-50 (rsb-a1) | 25.56 | 4.12 | 80.12 | 94.78 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.log.json) |
|
||||
| ResNet-50 (rsb-a2) | 25.56 | 4.12 | 79.55 | 94.37 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a2-300e_in1k_20211228-0fd8be6e.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a2-300e_in1k_20211228-0fd8be6e.log.json) |
|
||||
| ResNet-50 (rsb-a3) | 25.56 | 4.12 | 78.30 | 93.80 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.log.json) |
|
||||
|
||||
*"rsb" means the model is trained with the settings from [ResNet strikes back: An improved training procedure in timm](https://arxiv.org/abs/2110.00476).*
|
||||
|
|
|
@ -232,3 +232,70 @@ Models:
|
|||
Top 5 Accuracy: 93.07
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/fp16/resnet50_batch256_fp16_imagenet_20210320-b3964210.pth
|
||||
Config: configs/resnet/resnet50_8xb32-fp16_in1k.py
|
||||
- Name: resnet50_8xb256-rsb-a1-600e_in1k
|
||||
Metadata:
|
||||
FLOPs: 4120000000
|
||||
Parameters: 25560000
|
||||
Training Techniques:
|
||||
- LAMB
|
||||
- Weight Decay
|
||||
- Cosine Annealing
|
||||
- Mixup
|
||||
- CutMix
|
||||
- RepeatAugSampler
|
||||
- RandAugment
|
||||
Epochs: 600
|
||||
Batch Size: 2048
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 80.12
|
||||
Top 5 Accuracy: 94.78
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.pth
|
||||
Config: configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py
|
||||
- Name: resnet50_8xb256-rsb-a2-300e_in1k
|
||||
Metadata:
|
||||
FLOPs: 4120000000
|
||||
Parameters: 25560000
|
||||
Training Techniques:
|
||||
- LAMB
|
||||
- Weight Decay
|
||||
- Cosine Annealing
|
||||
- Mixup
|
||||
- CutMix
|
||||
- RepeatAugSampler
|
||||
- RandAugment
|
||||
Epochs: 300
|
||||
Batch Size: 2048
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 79.55
|
||||
Top 5 Accuracy: 94.37
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a2-300e_in1k_20211228-0fd8be6e.pth
|
||||
Config: configs/resnet/resnet50_8xb256-rsb-a2-300e_in1k.py
|
||||
- Name: resnet50_8xb256-rsb-a3-100e_in1k
|
||||
Metadata:
|
||||
FLOPs: 4120000000
|
||||
Parameters: 25560000
|
||||
Training Techniques:
|
||||
- LAMB
|
||||
- Weight Decay
|
||||
- Cosine Annealing
|
||||
- Mixup
|
||||
- CutMix
|
||||
- RandAugment
|
||||
Epochs: 100
Batch Size: 2048
|
||||
In Collection: ResNet
|
||||
Results:
|
||||
- Task: Image Classification
|
||||
Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 78.30
|
||||
Top 5 Accuracy: 93.80
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a3-100e_in1k_20211228-3493673c.pth
|
||||
Config: configs/resnet/resnet50_8xb256-rsb-a3-100e_in1k.py
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs256_rsb_a12.py',
|
||||
'../_base_/schedules/imagenet_bs2048_rsb.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
drop_path_rate=0.05,
|
||||
),
|
||||
head=dict(loss=dict(use_sigmoid=True)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.2, num_classes=1000, prob=0.5),
|
||||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
|
||||
]))
|
||||
|
||||
# Dataset settings
|
||||
sampler = dict(type='RepeatAugSampler')
|
||||
|
||||
# Schedule settings
|
||||
runner = dict(max_epochs=600)
|
||||
optimizer = dict(
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.),
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs256_rsb_a12.py',
|
||||
'../_base_/schedules/imagenet_bs2048_rsb.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
norm_cfg=dict(type='SyncBN', requires_grad=True),
|
||||
drop_path_rate=0.05,
|
||||
),
|
||||
head=dict(loss=dict(use_sigmoid=True)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
|
||||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
|
||||
]))
|
||||
|
||||
# Dataset settings
|
||||
sampler = dict(type='RepeatAugSampler')
|
||||
|
||||
# Schedule settings
|
||||
runner = dict(max_epochs=300)
|
||||
optimizer = dict(paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.))
|
|
@ -0,0 +1,19 @@
|
|||
_base_ = [
|
||||
'../_base_/models/resnet50.py',
|
||||
'../_base_/datasets/imagenet_bs256_rsb_a3.py',
|
||||
'../_base_/schedules/imagenet_bs2048_rsb.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# Model settings
|
||||
model = dict(
|
||||
backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
|
||||
head=dict(loss=dict(use_sigmoid=True)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.1, num_classes=1000, prob=0.5),
|
||||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
|
||||
]))
|
||||
|
||||
# Schedule settings
|
||||
optimizer = dict(
|
||||
lr=0.008, paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.))
|
|
@ -29,7 +29,7 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| RepVGG-D2se\* | 133.33 (train) | 120.39 (deploy) | 36.56 (train) | 32.85 (deploy) | 81.81 | 95.94 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-D2se_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth) |
|
||||
| ResNet-18 | 11.69 | 1.82 | 70.07 | 89.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.log.json) |
|
||||
| ResNet-34 | 21.8 | 3.68 | 73.85 | 91.53 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.log.json) |
|
||||
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.log.json) |
|
||||
| ResNet-50 (rsb-a1) | 25.56 | 4.12 | 80.12 | 94.78 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb256-rsb-a1-600e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb256-rsb-a1-600e_in1k_20211228-20e21305.log.json) |
|
||||
| ResNet-101 | 44.55 | 7.85 | 78.18 | 94.03 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.log.json) |
|
||||
| ResNet-152 | 60.19 | 11.58 | 78.63 | 94.16 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth) | [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.log.json) |
|
||||
| Res2Net-50-14w-8s\* | 25.06 | 4.22 | 78.14 | 93.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/res2net/res2net50-w14-s8_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/res2net/res2net50-w14-s8_3rdparty_8xb32_in1k_20210927-bc967bf1.pth) |
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
|
||||
constant_init)
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
eps = 1.0e-5
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
"""BasicBlock for ResNet.
|
||||
|
@ -42,7 +46,8 @@ class BasicBlock(nn.Module):
|
|||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN')):
|
||||
norm_cfg=dict(type='BN'),
|
||||
drop_path_rate=0.0):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
@ -83,6 +88,8 @@ class BasicBlock(nn.Module):
|
|||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.drop_path = DropPath(drop_prob=drop_path_rate
|
||||
) if drop_path_rate > eps else nn.Identity()
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
|
@ -107,6 +114,8 @@ class BasicBlock(nn.Module):
|
|||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.drop_path(out)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
@ -154,7 +163,8 @@ class Bottleneck(nn.Module):
|
|||
style='pytorch',
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN')):
|
||||
norm_cfg=dict(type='BN'),
|
||||
drop_path_rate=0.0):
|
||||
super(Bottleneck, self).__init__()
|
||||
assert style in ['pytorch', 'caffe']
|
||||
|
||||
|
@ -213,6 +223,8 @@ class Bottleneck(nn.Module):
|
|||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.drop_path = DropPath(drop_prob=drop_path_rate
|
||||
) if drop_path_rate > eps else nn.Identity()
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
|
@ -245,6 +257,8 @@ class Bottleneck(nn.Module):
|
|||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.drop_path(out)
|
||||
|
||||
out += identity
|
||||
|
||||
return out
|
||||
|
@ -465,7 +479,8 @@ class ResNet(BaseBackbone):
|
|||
type='Constant',
|
||||
val=1,
|
||||
layer=['_BatchNorm', 'GroupNorm'])
|
||||
]):
|
||||
],
|
||||
drop_path_rate=0.0):
|
||||
super(ResNet, self).__init__(init_cfg)
|
||||
if depth not in self.arch_settings:
|
||||
raise KeyError(f'invalid depth {depth} for resnet')
|
||||
|
@ -512,7 +527,8 @@ class ResNet(BaseBackbone):
|
|||
avg_down=self.avg_down,
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg)
|
||||
norm_cfg=norm_cfg,
|
||||
drop_path_rate=drop_path_rate)
|
||||
_in_channels = _out_channels
|
||||
_out_channels *= 2
|
||||
layer_name = f'layer{i + 1}'
|
||||
|
|
|
@ -4,4 +4,4 @@ from .cutmix import BatchCutMixLayer
|
|||
from .identity import Identity
|
||||
from .mixup import BatchMixupLayer
|
||||
|
||||
__all__ = ['Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer']
|
||||
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer')
|
||||
|
|
|
@ -456,6 +456,19 @@ def test_resnet():
|
|||
assert feat[2].shape == (1, 1024, 14, 14)
|
||||
assert feat[3].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with DropPath forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2, 3), drop_path_rate=0.5)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 4
|
||||
assert feat[0].shape == (1, 256, 56, 56)
|
||||
assert feat[1].shape == (1, 512, 28, 28)
|
||||
assert feat[2].shape == (1, 1024, 14, 14)
|
||||
assert feat[3].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with layers 1, 2, 3 out forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2))
|
||||
model.init_weights()
|
||||
|
|
Loading…
Reference in New Issue