Merge pull request #1780 from timerring/dev
[CodeCamp2023-338] New Version of config Adapting Swin Transformer Algorithmpull/1853/head
commit
5c71de6b8e
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
|
||||
PackInputs, RandomCrop, RandomFlip, Resize)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = CUB
|
||||
data_preprocessor = dict(
|
||||
num_classes=200,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=RandomCrop, crop_size=384),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(type=Resize, scale=510),
|
||||
dict(type=CenterCrop, crop_size=384),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=2,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/CUB_200_2011',
|
||||
split='test',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, ))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.dataset import DefaultSampler
|
||||
|
||||
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
|
||||
PackInputs, RandAugment, RandomErasing,
|
||||
RandomFlip, RandomResizedCrop, ResizeEdge)
|
||||
from mmpretrain.evaluation import Accuracy
|
||||
|
||||
# dataset settings
|
||||
dataset_type = ImageNet
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=RandomResizedCrop,
|
||||
scale=256,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type=RandAugment,
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type=RandomErasing,
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type=LoadImageFromFile),
|
||||
dict(
|
||||
type=ResizeEdge,
|
||||
scale=292, # ( 256 / 224 * 256 )
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type=CenterCrop, crop_size=256),
|
||||
dict(type=PackInputs),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
split='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type=DefaultSampler, shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type=Accuracy, topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
|
||||
ImageClassifier, LinearClsHead, SwinTransformer)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformer,
|
||||
arch='base',
|
||||
img_size=384,
|
||||
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmpretrain.models import (GlobalAveragePooling, ImageClassifier,
|
||||
LabelSmoothLoss, LinearClsHead,
|
||||
SwinTransformerV2)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type=ImageClassifier,
|
||||
backbone=dict(
|
||||
type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2),
|
||||
neck=dict(type=GlobalAveragePooling),
|
||||
head=dict(
|
||||
type=LinearClsHead,
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False))
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.optim import CosineAnnealingLR, LinearLR
|
||||
from torch.optim import SGD
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True))
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
# warm up learning rate scheduler
|
||||
dict(
|
||||
type=LinearLR,
|
||||
start_factor=0.01,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=5,
|
||||
# update by iter
|
||||
convert_to_iter_based=True),
|
||||
# main learning rate scheduler
|
||||
dict(
|
||||
type=CosineAnnealingLR,
|
||||
T_max=95,
|
||||
by_epoch=True,
|
||||
begin=5,
|
||||
end=100,
|
||||
)
|
||||
]
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=64)
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=224, drop_path_rate=0.5, stage_cfgs=None),
|
||||
head=dict(
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large', img_size=224, stage_cfgs=None),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='large'),
|
||||
head=dict(in_channels=1536),
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.hooks import CheckpointHook, LoggerHook
|
||||
from mmengine.model import PretrainedInit
|
||||
from torch.optim.adamw import AdamW
|
||||
|
||||
from mmpretrain.models import ImageClassifier
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.cub_bs8_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.cub_bs64 import *
|
||||
|
||||
# model settings
|
||||
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
|
||||
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
init_cfg=dict(
|
||||
type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')),
|
||||
head=dict(num_classes=200, in_channels=1536))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
_delete_=True,
|
||||
type=AdamW,
|
||||
lr=5e-6,
|
||||
weight_decay=0.0005,
|
||||
eps=1e-8,
|
||||
betas=(0.9, 0.999)),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0.0,
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={
|
||||
'.absolute_pos_embed': dict(decay_mult=0.0),
|
||||
'.relative_position_bias_table': dict(decay_mult=0.0)
|
||||
}),
|
||||
clip_grad=dict(max_norm=5.0),
|
||||
)
|
||||
|
||||
default_hooks = dict(
|
||||
# log every 20 intervals
|
||||
logger=dict(type=LoggerHook, interval=20),
|
||||
# save last three checkpoints
|
||||
checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small', img_size=224, drop_path_rate=0.3, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, LabelSmoothLoss, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_224 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='tiny', img_size=224, drop_path_rate=0.2, stage_cfgs=None),
|
||||
head=dict(
|
||||
in_channels=768,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type=LabelSmoothLoss,
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
loss_weight=0),
|
||||
topk=None,
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256, drop_path_rate=0.5, window_size=[16, 16, 16, 8]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
window_size=[24, 24, 24, 12], pretrained_window_sizes=[12, 12, 12, 6]))
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(img_size=256, drop_path_rate=0.5),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet21k_bs128 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
img_size=192, drop_path_rate=0.5, window_size=[12, 12, 12, 6]),
|
||||
head=dict(num_classes=21841),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
||||
|
||||
# dataset settings
|
||||
data_preprocessor = dict(num_classes=21841)
|
||||
|
||||
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
|
||||
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
|
||||
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=256,
|
||||
window_size=[16, 16, 16, 8],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,24 @@
|
|||
# Only for evaluation
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
|
||||
from mmpretrain.models import CrossEntropyLoss
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_384 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='large',
|
||||
img_size=384,
|
||||
window_size=[24, 24, 24, 12],
|
||||
pretrained_window_sizes=[12, 12, 12, 6]),
|
||||
head=dict(
|
||||
in_channels=1536,
|
||||
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
|
||||
topk=(1, 5)))
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='small',
|
||||
img_size=256,
|
||||
drop_path_rate=0.3,
|
||||
window_size=[16, 16, 16, 8]),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='small', img_size=256, drop_path_rate=0.3),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(
|
||||
arch='tiny',
|
||||
img_size=256,
|
||||
drop_path_rate=0.2,
|
||||
window_size=[16, 16, 16, 8]),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# This is a BETA new format config file, and the usage may change recently.
|
||||
from mmengine.config import read_base
|
||||
from mmengine.model import ConstantInit, TruncNormalInit
|
||||
|
||||
from mmpretrain.models import CutMix, Mixup
|
||||
|
||||
with read_base():
|
||||
from .._base_.datasets.imagenet_bs64_swin_256 import *
|
||||
from .._base_.default_runtime import *
|
||||
from .._base_.models.swin_transformer_v2_base import *
|
||||
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
|
||||
|
||||
# model settings
|
||||
model.update(
|
||||
backbone=dict(arch='tiny', img_size=256, drop_path_rate=0.2),
|
||||
head=dict(in_channels=768),
|
||||
init_cfg=[
|
||||
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
|
||||
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(
|
||||
augments=[dict(type=Mixup, alpha=0.8),
|
||||
dict(type=CutMix, alpha=1.0)]))
|
Loading…
Reference in New Issue