[CodeCamp2023-338] New Version of config Adapting Swin Transformer Algorithm

pull/1780/head
John 2023-08-31 18:15:47 +08:00
parent e1675e893e
commit 634852ad61
35 changed files with 834 additions and 0 deletions

View File

@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CUB, CenterCrop, LoadImageFromFile,
PackInputs, RandomCrop, RandomFlip, Resize)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = CUB
data_preprocessor = dict(
num_classes=200,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type=LoadImageFromFile),
dict(type=Resize, scale=510),
dict(type=RandomCrop, crop_size=384),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(type=Resize, scale=510),
dict(type=CenterCrop, crop_size=384),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/CUB_200_2011',
split='test',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, ))
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
PackInputs, RandAugment, RandomErasing,
RandomFlip, RandomResizedCrop, ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type=LoadImageFromFile),
dict(
type=RandomResizedCrop,
scale=256,
backend='pillow',
interpolation='bicubic'),
dict(type=RandomFlip, prob=0.5, direction='horizontal'),
dict(
type=RandAugment,
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type=RandomErasing,
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type=PackInputs),
]
test_pipeline = [
dict(type=LoadImageFromFile),
dict(
type=ResizeEdge,
scale=292, # ( 256 / 224 * 256 )
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type=CenterCrop, crop_size=256),
dict(type=PackInputs),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=384, backend='pillow', interpolation='bicubic'),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
split='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformer)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformer, arch='base', img_size=224, drop_path_rate=0.5),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead, SwinTransformer)
# model settings
# Only for evaluation
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformer,
arch='base',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1024,
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
topk=(1, 5)))

View File

@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead, SwinTransformer)
# model settings
# Only for evaluation
model = dict(
type=ImageClassifier,
backbone=dict(type=SwinTransformer, arch='large', img_size=224),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1536,
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
topk=(1, 5)))

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead, SwinTransformer)
# model settings
# Only for evaluation
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformer,
arch='large',
img_size=384,
stage_cfgs=dict(block_cfgs=dict(window_size=12))),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1536,
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
topk=(1, 5)))

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformer)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformer, arch='small', img_size=224, drop_path_rate=0.3),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformer)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformer, arch='tiny', img_size=224, drop_path_rate=0.2),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformerV2)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='base', img_size=256, drop_path_rate=0.5),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead,
SwinTransformerV2)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='base', img_size=384, drop_path_rate=0.2),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False))

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead,
SwinTransformerV2)
# model settings
# Only for evaluation
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='large', img_size=256,
drop_path_rate=0.2),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1536,
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
topk=(1, 5)))

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead,
SwinTransformerV2)
# model settings
# Only for evaluation
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='large', img_size=384,
drop_path_rate=0.2),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=1536,
loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
topk=(1, 5)))

View File

@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformerV2)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='small', img_size=256,
drop_path_rate=0.3),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model import ConstantInit, TruncNormalInit
from mmpretrain.models import (CutMix, GlobalAveragePooling, ImageClassifier,
LabelSmoothLoss, LinearClsHead, Mixup,
SwinTransformerV2)
# model settings
model = dict(
type=ImageClassifier,
backbone=dict(
type=SwinTransformerV2, arch='tiny', img_size=256, drop_path_rate=0.2),
neck=dict(type=GlobalAveragePooling),
head=dict(
type=LinearClsHead,
num_classes=1000,
in_channels=768,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(type=LabelSmoothLoss, label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type=TruncNormalInit, layer='Linear', std=0.02, bias=0.),
dict(type=ConstantInit, layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(
augments=[dict(type=Mixup, alpha=0.8),
dict(type=CutMix, alpha=1.0)]),
)

View File

@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.optim import CosineAnnealingLR, LinearLR
from torch.optim import SGD
# optimizer
optim_wrapper = dict(
optimizer=dict(
type=SGD, lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True))
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type=LinearLR,
start_factor=0.01,
by_epoch=True,
begin=0,
end=5,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type=CosineAnnealingLR,
T_max=95,
by_epoch=True,
begin=5,
end=100,
)
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=64)

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_384 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.base_384 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_224 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.base_224 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_384 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.large_384 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_224 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.large_224 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmengine.hooks import CheckpointHook, LoggerHook
from mmengine.model import PretrainedInit
from torch.optim.adamw import AdamW
from mmpretrain.models import ImageClassifier
with read_base():
from .._base_.datasets.cub_bs8_384 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.large_384 import *
from .._base_.schedules.cub_bs64 import *
# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
model = dict(
type=ImageClassifier,
backbone=dict(
init_cfg=dict(
type=PretrainedInit, checkpoint=checkpoint, prefix='backbone')),
head=dict(num_classes=200, ))
# schedule settings
optim_wrapper = dict(
optimizer=dict(
_delete_=True,
type=AdamW,
lr=5e-6,
weight_decay=0.0005,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
}),
clip_grad=dict(max_norm=5.0),
)
default_hooks = dict(
# log every 20 intervals
logger=dict(type=LoggerHook, interval=20),
# save last three checkpoints
checkpoint=dict(type=CheckpointHook, interval=1, max_keep_ckpts=3))

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_224 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.small_224 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_224 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer.tiny_224 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

View File

@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet21k_bs128 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# model settings
model = dict(
backbone=dict(img_size=192, window_size=[12, 12, 12, 6]),
head=dict(num_classes=21841),
)
# dataset settings
data_preprocessor = dict(num_classes=21841)
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(backbone=dict(window_size=[16, 16, 16, 8]))

View File

@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmpretrain.models import ImageClassifier
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(
type=ImageClassifier,
backbone=dict(
window_size=[16, 16, 16, 8],
drop_path_rate=0.2,
pretrained_window_sizes=[12, 12, 12, 6]))

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmpretrain.models import ImageClassifier
with read_base():
from .._base_.datasets.imagenet_bs64_swin_384 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_384 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(
type=ImageClassifier,
backbone=dict(
img_size=384,
window_size=[24, 24, 24, 12],
drop_path_rate=0.2,
pretrained_window_sizes=[12, 12, 12, 6]))

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *

View File

@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet21k_bs128 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.base_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
# model settings
model = dict(
backbone=dict(img_size=192, window_size=[12, 12, 12, 6]),
head=dict(num_classes=21841),
)
# dataset settings
data_preprocessor = dict(num_classes=21841)
_base_['train_pipeline'][1]['scale'] = 192 # RandomResizedCrop
_base_['test_pipeline'][1]['scale'] = 219 # ResizeEdge
_base_['test_pipeline'][2]['crop_size'] = 192 # CenterCrop

View File

@ -0,0 +1,18 @@
# Only for evaluation
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmpretrain.models import ImageClassifier
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.large_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(
type=ImageClassifier,
backbone=dict(
window_size=[16, 16, 16, 8], pretrained_window_sizes=[12, 12, 12, 6]),
)

View File

@ -0,0 +1,20 @@
# Only for evaluation
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
from mmpretrain.models import ImageClassifier
with read_base():
from .._base_.datasets.imagenet_bs64_swin_384 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.large_384 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(
type=ImageClassifier,
backbone=dict(
img_size=384,
window_size=[24, 24, 24, 12],
pretrained_window_sizes=[12, 12, 12, 6]),
)

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.small_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(backbone=dict(window_size=[16, 16, 16, 8]))

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.small_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.tiny_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *
model = dict(backbone=dict(window_size=[16, 16, 16, 8]))

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs64_swin_256 import *
from .._base_.default_runtime import *
from .._base_.models.swin_transformer_v2.tiny_256 import *
from .._base_.schedules.imagenet_bs1024_adamw_swin import *