[CodeCamp2023-340] New Version of config Adapting MobileNet Algorithm (#1774)

* add new config adapting MobileNetV2,V3

* add a base model config for MobileNetV3, and modify all MobileNetV3 training configs to inherit from the base model config

* removed directory _base_/models/mobilenet_v3
pull/1906/head
DE009 2023-09-01 17:54:18 +08:00 committed by GitHub
parent d2ccc44a2c
commit 845b462190
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 554 additions and 0 deletions

View File

@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip
from mmpretrain.evaluation import Accuracy
# dataset settings
# Dataset class used as the `type` of the `dataset` entry in both the train
# and val dataloaders below.
dataset_type = CIFAR10
# Data preprocessor settings for the 10-class dataset.
data_preprocessor = {
    'num_classes': 10,
    # Per-channel normalization statistics, listed in RGB order.
    'mean': [125.307, 122.961, 113.8575],
    'std': [51.5865, 50.847, 51.255],
    # No channel conversion is requested: loaded images are already RGB.
    'to_rgb': False,
}
# Training transforms: padded random crop, random horizontal flip, then
# packing of the inputs.
train_pipeline = [
    {'type': RandomCrop, 'crop_size': 32, 'padding': 4},
    {'type': RandomFlip, 'prob': 0.5, 'direction': 'horizontal'},
    {'type': PackInputs},
]

# Evaluation applies no augmentation; samples are only packed.
test_pipeline = [{'type': PackInputs}]

# Shuffled loader over the 'train' split.
train_dataloader = {
    'batch_size': 16,
    'num_workers': 2,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/cifar10',
        'split': 'train',
        'pipeline': train_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': True},
}

# Unshuffled loader over the 'test' split, used for validation.
val_dataloader = {
    'batch_size': 16,
    'num_workers': 2,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/cifar10/',
        'split': 'test',
        'pipeline': test_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': False},
}

# Only top-1 accuracy is requested.
val_evaluator = {'type': Accuracy, 'topk': (1, )}

# Testing reuses the validation dataloader and evaluator as-is.
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet,
LoadImageFromFile, PackInputs, RandomErasing,
RandomFlip, RandomResizedCrop, ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
# Dataset class used as the `type` of the `dataset` entry in both the train
# and val dataloaders below.
dataset_type = ImageNet
# Data preprocessor settings for the 1000-class dataset.
data_preprocessor = {
    'num_classes': 1000,
    # Per-channel normalization statistics, listed in RGB order.
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    # Request a BGR -> RGB channel conversion.
    'to_rgb': True,
}

# The same statistics with the channel order reversed (i.e. BGR order), as
# new lists independent of the ones stored in `data_preprocessor`.
bgr_mean = list(reversed(data_preprocessor['mean']))
bgr_std = list(reversed(data_preprocessor['std']))
# Training transforms: load, random-resized crop to 224, random horizontal
# flip, AutoAugment with the 'imagenet' policies, random erasing, then packing.
train_pipeline = [
    {'type': LoadImageFromFile},
    {'type': RandomResizedCrop, 'scale': 224, 'backend': 'pillow'},
    {'type': RandomFlip, 'prob': 0.5, 'direction': 'horizontal'},
    {
        'type': AutoAugment,
        'policies': 'imagenet',
        # Pad value is the BGR mean rounded to integers.
        'hparams': {'pad_val': list(map(round, bgr_mean))},
    },
    {
        'type': RandomErasing,
        'erase_prob': 0.2,
        'mode': 'rand',
        'min_area_ratio': 0.02,
        'max_area_ratio': 1 / 3,
        'fill_color': bgr_mean,
        'fill_std': bgr_std,
    },
    {'type': PackInputs},
]

# Evaluation transforms: load, resize the short edge to 256, center-crop to
# 224, then packing.
test_pipeline = [
    {'type': LoadImageFromFile},
    {'type': ResizeEdge, 'scale': 256, 'edge': 'short', 'backend': 'pillow'},
    {'type': CenterCrop, 'crop_size': 224},
    {'type': PackInputs},
]

# Shuffled loader over the 'train' split.
train_dataloader = {
    'batch_size': 128,
    'num_workers': 5,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/imagenet',
        'split': 'train',
        'pipeline': train_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': True},
}

# Unshuffled loader over the 'val' split.
val_dataloader = {
    'batch_size': 128,
    'num_workers': 5,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/imagenet',
        'split': 'val',
        'pipeline': test_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': False},
}

# Top-1 and top-5 accuracy.
val_evaluator = {'type': Accuracy, 'topk': (1, 5)}

# If you want standard test, please manually configure the test dataset.
# By default, testing reuses the validation dataloader and evaluator.
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler
from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
PackInputs, RandomFlip, RandomResizedCrop,
ResizeEdge)
from mmpretrain.evaluation import Accuracy
# dataset settings
# Dataset class used as the `type` of the `dataset` entry in both the train
# and val dataloaders below.
dataset_type = ImageNet
# Data preprocessor settings for the 1000-class dataset.
data_preprocessor = {
    'num_classes': 1000,
    # Per-channel normalization statistics, listed in RGB order.
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    # Request a BGR -> RGB channel conversion.
    'to_rgb': True,
}
# Training transforms: load, random-resized crop to 224, random horizontal
# flip, then packing.
train_pipeline = [
    {'type': LoadImageFromFile},
    {'type': RandomResizedCrop, 'scale': 224, 'backend': 'pillow'},
    {'type': RandomFlip, 'prob': 0.5, 'direction': 'horizontal'},
    {'type': PackInputs},
]

# Evaluation transforms: load, resize the short edge to 256, center-crop to
# 224, then packing.
test_pipeline = [
    {'type': LoadImageFromFile},
    {'type': ResizeEdge, 'scale': 256, 'edge': 'short', 'backend': 'pillow'},
    {'type': CenterCrop, 'crop_size': 224},
    {'type': PackInputs},
]

# Shuffled loader over the 'train' split.
train_dataloader = {
    'batch_size': 32,
    'num_workers': 5,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/imagenet',
        'split': 'train',
        'pipeline': train_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': True},
}

# Unshuffled loader over the 'val' split.
val_dataloader = {
    'batch_size': 32,
    'num_workers': 5,
    'dataset': {
        'type': dataset_type,
        'data_root': 'data/imagenet',
        'split': 'val',
        'pipeline': test_pipeline,
    },
    'sampler': {'type': DefaultSampler, 'shuffle': False},
}

# Top-1 and top-5 accuracy.
val_evaluator = {'type': Accuracy, 'topk': (1, 5)}

# If you want standard test, please manually configure the test dataset.
# By default, testing reuses the validation dataloader and evaluator.
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, LinearClsHead, MobileNetV2)
# model settings
# Image classifier assembled from a MobileNetV2 backbone (widen factor 1.0),
# a global-average-pooling neck, and a linear classification head.
model = {
    'type': ImageClassifier,
    'backbone': {'type': MobileNetV2, 'widen_factor': 1.0},
    'neck': {'type': GlobalAveragePooling},
    'head': {
        'type': LinearClsHead,
        'num_classes': 1000,
        'in_channels': 1280,
        'loss': {'type': CrossEntropyLoss, 'loss_weight': 1.0},
        'topk': (1, 5),
    },
}

View File

@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model.weight_init import NormalInit
from torch.nn.modules.activation import Hardswish
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
ImageClassifier, MobileNetV3,
StackedLinearClsHead)
# model settings
# Image classifier assembled from a MobileNetV3 'small' backbone, a
# global-average-pooling neck, and a stacked-linear classification head.
model = {
    'type': ImageClassifier,
    'backbone': {'type': MobileNetV3, 'arch': 'small'},
    'neck': {'type': GlobalAveragePooling},
    'head': {
        'type': StackedLinearClsHead,
        'num_classes': 1000,
        'in_channels': 576,
        'mid_channels': [1024],
        'dropout_rate': 0.2,
        'act_cfg': {'type': Hardswish},
        'loss': {'type': CrossEntropyLoss, 'loss_weight': 1.0},
        # Normal initialization targeted at 'Linear' layers.
        'init_cfg': {
            'type': NormalInit,
            'layer': 'Linear',
            'mean': 0.,
            'std': 0.01,
            'bias': 0.,
        },
        'topk': (1, 5),
    },
}

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.optim import MultiStepLR
from torch.optim import SGD
# optimizer: SGD with momentum and weight decay
optim_wrapper = {
    'optimizer': {
        'type': SGD,
        'lr': 0.1,
        'momentum': 0.9,
        'weight_decay': 0.0001,
    },
}

# learning policy: epoch-based multi-step schedule with milestones at
# epochs 100 and 150
param_scheduler = {
    'type': MultiStepLR,
    'by_epoch': True,
    'milestones': [100, 150],
    'gamma': 0.1,
}
# train, val, test setting
# Epoch-based training for 200 epochs, validating every epoch.
train_cfg = {'by_epoch': True, 'max_epochs': 200, 'val_interval': 1}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = {'base_batch_size': 128}

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.optim import StepLR
from torch.optim import SGD
# optimizer: SGD with momentum and weight decay
optim_wrapper = {
    'optimizer': {
        'type': SGD,
        'lr': 0.045,
        'momentum': 0.9,
        'weight_decay': 0.00004,
    },
}

# learning policy: epoch-based step schedule (step_size=1, gamma=0.98)
param_scheduler = {
    'type': StepLR,
    'by_epoch': True,
    'step_size': 1,
    'gamma': 0.98,
}
# train, val, test setting
# Epoch-based training for 300 epochs, validating every epoch.
train_cfg = {'by_epoch': True, 'max_epochs': 300, 'val_interval': 1}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = {'base_batch_size': 256}

View File

@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.datasets.imagenet_bs32_pil_resize import *
from .._base_.default_runtime import *
from .._base_.models.mobilenet_v2_1x import *
from .._base_.schedules.imagenet_bs256_epochstep import *

View File

@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
from mmengine.config import read_base
with read_base():
from .._base_.models.mobilenet_v3_small import *
from .._base_.datasets.imagenet_bs128_mbv3 import *
from .._base_.default_runtime import *
from mmengine.optim import StepLR
from torch.optim import RMSprop
# model settings
# Override the inherited model in place: 'large' backbone architecture and
# the head's input / intermediate channel sizes.
model.merge({
    'backbone': {'arch': 'large'},
    'head': {'in_channels': 960, 'mid_channels': [1280]},
})

# schedule settings
optim_wrapper = {
    'optimizer': {
        'type': RMSprop,
        'lr': 0.064,
        'alpha': 0.9,
        'momentum': 0.9,
        'eps': 0.0316,
        'weight_decay': 1e-5,
    },
}

# learning policy: epoch-based step schedule (step_size=2, gamma=0.973)
param_scheduler = {
    'type': StepLR,
    'by_epoch': True,
    'step_size': 2,
    'gamma': 0.973,
}
# Epoch-based training for 600 epochs, validating every epoch.
train_cfg = {'by_epoch': True, 'max_epochs': 600, 'val_interval': 1}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = {'base_batch_size': 1024}

View File

@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
from mmengine.config import read_base
with read_base():
from .._base_.models.mobilenet_v3_small import *
from .._base_.datasets.imagenet_bs128_mbv3 import *
from .._base_.default_runtime import *
from mmengine.optim import StepLR
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.optim import RMSprop
# model settings
# Override the inherited model in place: 'small_050' backbone architecture
# with explicit BatchNorm2d settings, and a head input width of 288.
model.merge({
    'backbone': {
        'arch': 'small_050',
        'norm_cfg': {'type': BatchNorm2d, 'eps': 1e-5, 'momentum': 0.1},
    },
    'head': {'in_channels': 288},
})

# Training transforms; resizing uses bicubic interpolation.
train_pipeline = [
    {'type': LoadImageFromFile},
    {
        'type': RandomResizedCrop,
        'scale': 224,
        'backend': 'pillow',
        'interpolation': 'bicubic',
    },
    {'type': RandomFlip, 'prob': 0.5, 'direction': 'horizontal'},
    {
        'type': AutoAugment,
        'policies': 'imagenet',
        # Pad value is the listed per-channel values rounded to integers.
        'hparams': {'pad_val': list(map(round, [103.53, 116.28, 123.675]))},
    },
    {
        'type': RandomErasing,
        'erase_prob': 0.2,
        'mode': 'rand',
        'min_area_ratio': 0.02,
        'max_area_ratio': 1 / 3,
        'fill_color': [103.53, 116.28, 123.675],
        'fill_std': [57.375, 57.12, 58.395],
    },
    {'type': PackInputs},
]

# Evaluation transforms; resizing uses bicubic interpolation.
test_pipeline = [
    {'type': LoadImageFromFile},
    {
        'type': ResizeEdge,
        'scale': 256,
        'edge': 'short',
        'backend': 'pillow',
        'interpolation': 'bicubic',
    },
    {'type': CenterCrop, 'crop_size': 224},
    {'type': PackInputs},
]

# Swap the pipelines defined above into the inherited dataloaders.
train_dataloader.merge({'dataset': {'pipeline': train_pipeline}})
val_dataloader.merge({'dataset': {'pipeline': test_pipeline}})
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = {
    'optimizer': {
        'type': RMSprop,
        'lr': 0.064,
        'alpha': 0.9,
        'momentum': 0.9,
        'eps': 0.0316,
        'weight_decay': 1e-5,
    },
}

# learning policy: epoch-based step schedule (step_size=2, gamma=0.973)
param_scheduler = {
    'type': StepLR,
    'by_epoch': True,
    'step_size': 2,
    'gamma': 0.973,
}
# Epoch-based training for 600 epochs, validating every 10 epochs.
train_cfg = {'by_epoch': True, 'max_epochs': 600, 'val_interval': 10}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = {'base_batch_size': 1024}

View File

@ -0,0 +1,83 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
from mmengine.config import read_base
with read_base():
from .._base_.models.mobilenet_v3_small import *
from .._base_.datasets.imagenet_bs128_mbv3 import *
from .._base_.default_runtime import *
from mmengine.optim import StepLR
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.optim import RMSprop
# model settings
# Override the inherited model in place: 'small_075' backbone architecture
# with explicit BatchNorm2d settings, and a head input width of 432.
model.merge({
    'backbone': {
        'arch': 'small_075',
        'norm_cfg': {'type': BatchNorm2d, 'eps': 1e-5, 'momentum': 0.1},
    },
    'head': {'in_channels': 432},
})

# Training transforms; resizing uses bicubic interpolation.
train_pipeline = [
    {'type': LoadImageFromFile},
    {
        'type': RandomResizedCrop,
        'scale': 224,
        'backend': 'pillow',
        'interpolation': 'bicubic',
    },
    {'type': RandomFlip, 'prob': 0.5, 'direction': 'horizontal'},
    {
        'type': AutoAugment,
        'policies': 'imagenet',
        # Pad value is the listed per-channel values rounded to integers.
        'hparams': {'pad_val': list(map(round, [103.53, 116.28, 123.675]))},
    },
    {
        'type': RandomErasing,
        'erase_prob': 0.2,
        'mode': 'rand',
        'min_area_ratio': 0.02,
        'max_area_ratio': 1 / 3,
        'fill_color': [103.53, 116.28, 123.675],
        'fill_std': [57.375, 57.12, 58.395],
    },
    {'type': PackInputs},
]

# Evaluation transforms; resizing uses bicubic interpolation.
test_pipeline = [
    {'type': LoadImageFromFile},
    {
        'type': ResizeEdge,
        'scale': 256,
        'edge': 'short',
        'backend': 'pillow',
        'interpolation': 'bicubic',
    },
    {'type': CenterCrop, 'crop_size': 224},
    {'type': PackInputs},
]

# Swap the pipelines defined above into the inherited dataloaders.
train_dataloader.merge({'dataset': {'pipeline': train_pipeline}})
val_dataloader.merge({'dataset': {'pipeline': test_pipeline}})
# Testing reuses the (updated) validation dataloader.
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = {
    'optimizer': {
        'type': RMSprop,
        'lr': 0.064,
        'alpha': 0.9,
        'momentum': 0.9,
        'eps': 0.0316,
        'weight_decay': 1e-5,
    },
}

# learning policy: epoch-based step schedule (step_size=2, gamma=0.973)
param_scheduler = {
    'type': StepLR,
    'by_epoch': True,
    'step_size': 2,
    'gamma': 0.973,
}
# Epoch-based training for 600 epochs, validating every 10 epochs.
train_cfg = {'by_epoch': True, 'max_epochs': 600, 'val_interval': 10}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = {'base_batch_size': 1024}

View File

@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
from mmengine.config import read_base
with read_base():
from .._base_.models.mobilenet_v3_small import *
from .._base_.datasets.imagenet_bs128_mbv3 import *
from .._base_.default_runtime import *
from mmengine.optim import StepLR
from torch.optim import RMSprop
# schedule settings
optim_wrapper = {
    'optimizer': {
        'type': RMSprop,
        'lr': 0.064,
        'alpha': 0.9,
        'momentum': 0.9,
        'eps': 0.0316,
        'weight_decay': 1e-5,
    },
}

# learning policy: epoch-based step schedule (step_size=2, gamma=0.973)
param_scheduler = {
    'type': StepLR,
    'by_epoch': True,
    'step_size': 2,
    'gamma': 0.973,
}
# Epoch-based training for 600 epochs, validating every epoch.
train_cfg = {'by_epoch': True, 'max_epochs': 600, 'val_interval': 1}
# Validation and test loops take no extra options (two separate empty dicts).
val_cfg = {}
test_cfg = {}

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = {'base_batch_size': 1024}

View File

@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base
with read_base():
from .._base_.models.mobilenet_v3_small import *
from .._base_.datasets.cifar10_bs16 import *
from .._base_.schedules.cifar10_bs128 import *
from .._base_.default_runtime import *
from mmengine.optim import MultiStepLR
# model settings
# Replace the inherited head wholesale (`_delete_` is set, so the base head's
# keys are not kept) with a 10-class stacked-linear head.
model.merge({
    'head': {
        '_delete_': True,
        'type': StackedLinearClsHead,
        'num_classes': 10,
        'in_channels': 576,
        'mid_channels': [1280],
        'act_cfg': {'type': Hardswish},
        'loss': {'type': CrossEntropyLoss, 'loss_weight': 1.0},
        'topk': (1, 5),
    },
})

# schedule settings
# Update the inherited scheduler: multi-step schedule with milestones at
# epochs 120 and 170.
param_scheduler.merge({
    'type': MultiStepLR,
    'by_epoch': True,
    'milestones': [120, 170],
    'gamma': 0.1,
})

# Epoch-based training for 200 epochs; other inherited keys are left as-is.
train_cfg.merge({'by_epoch': True, 'max_epochs': 200})