[Improvement] Adapt OFA series with SearchableMobileNetV3 (#385)
* fix mutable bug in AttentiveMobileNetV3
* remove unnecessary code
* update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names
* unify the sampling usage in sandwich_rule-based NAS
* use alias to export subnet
* update OFA configs
* fix attr bug
* fix comments
* update convert_supernet2subnet.py
* correct the way to dump DerivedMutable
* fix convert index bug
* update OFA configs & models
* fix dynamic2static
* generalize convert_ofa_ckpt.py
* update input_resizer
* update README.md
* fix ut
* update export_fix_subnet
* update _dynamic_to_static
* update fix_subnet UT & fix minor bugs
* fix ut
* add new autoaug compared to attentivenas
* clean
* fix act
* fix act_cfg
* update fix_subnet
* fix lint
* add docstring

Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
parent f886821ba1
commit 42e8de73af
@@ -40,12 +40,8 @@ arch_setting = dict(
    [1792, 1984, 1984 - 1792],  # last layer
])

-_INPUT_MUTABLE = dict(
-    input_resizer=dict(type='DynamicInputResizer'),
-    mutable_shape=dict(
-        type='OneShotMutableValue',
-        value_list=[[192, 192], [224, 224], [256, 256], [288, 288]],
-        default_value=[224, 224]))
+input_resizer_cfg = dict(
+    input_sizes=[[192, 192], [224, 224], [256, 256], [288, 288]])

nas_backbone = dict(
    type='AttentiveMobileNetV3',
@@ -37,12 +37,21 @@ arch_setting = dict(
    [1024, 1280, 1280 - 1024],  # last layer
])

input_resizer_cfg = dict(
    input_sizes=[[128, 128], [140, 140], [144, 144], [152, 152], [192, 192],
                 [204, 204], [224, 224], [256, 256]])

nas_backbone = dict(
    type='mmrazor.AttentiveMobileNetV3',
    arch_setting=arch_setting,
    out_indices=(6, ),
    stride_list=[1, 2, 2, 2, 1, 2],
    with_se_list=[False, False, True, False, True, True],
    act_cfg_list=[
        'HSwish', 'ReLU', 'ReLU', 'ReLU', 'HSwish', 'HSwish', 'HSwish',
        'HSwish', 'HSwish'
    ],
    conv_cfg=dict(type='OFAConv2d'),
-    norm_cfg=dict(type='mmrazor.DynamicBatchNorm2d', momentum=0.0),
+    norm_cfg=dict(type='mmrazor.DynamicBatchNorm2d', momentum=0.1),
    fine_grained_mode=True,
    with_attentive_shortcut=False)
@@ -145,7 +145,11 @@ policies = [
            prob=1.0,
            magnitude=9,
            extra_params=extra_params),
-        dict(type='ShearY', prob=0.6, magnitude=3, extra_params=extra_params),
+        dict(
+            type='mmrazor.ShearY',
+            prob=0.6,
+            magnitude=3,
+            extra_params=extra_params),
    ],
    [
        dict(
@@ -0,0 +1,98 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'

# data preprocessor
data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]

extra_params = dict(
    translate_const=int(224 * 0.45),
    img_mean=tuple(round(x) for x in data_preprocessor['mean']),
)

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.ColorJitter', brightness=0.1254, saturation=0.5),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(
        type='mmcls.ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow',
        interpolation='bilinear'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs')
]

train_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.RepeatAugSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want a standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
optim_wrapper = dict(
    optimizer=dict(
        type='SGD', lr=0.8, momentum=0.9, weight_decay=0.00001, nesterov=True),
    paramwise_cfg=dict(bias_decay_mult=0., norm_decay_mult=0.))

# learning policy
max_epochs = 360
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=3125),
    dict(
        type='CosineAnnealingLR',
        T_max=max_epochs,
        eta_min=0,
        by_epoch=True,
        begin=0,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='mmrazor.SubnetValLoop', calibrate_sample_num=4096)
test_cfg = dict(type='mmrazor.SubnetValLoop', calibrate_sample_num=4096)
@@ -43,6 +43,4 @@ model_wrapper_cfg = dict(
    broadcast_buffers=False,
    find_unused_parameters=True)

-optim_wrapper = dict(accumulative_counts=3)
-
val_cfg = dict(type='mmrazor.SlimmableValLoop')
@@ -21,12 +21,12 @@ data_preprocessor = dict(
)

# !autoslim algorithm config
-num_samples = 2
+num_random_samples = 2
model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='AutoSlim',
-    num_samples=num_samples,
+    num_random_samples=num_random_samples,
    architecture=supernet,
    data_preprocessor=data_preprocessor,
    distiller=dict(
@@ -59,8 +59,6 @@ model_wrapper_cfg = dict(
    broadcast_buffers=False,
    find_unused_parameters=False)

-optim_wrapper = dict(accumulative_counts=num_samples + 2)
-
# learning policy
max_epochs = 50
param_scheduler = dict(end=max_epochs)
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1792
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 4
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 3
backbone.layers.4.expand_ratio:
  chosen: 4
backbone.layers.4.depth:
  chosen: 3
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 3
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 3
backbone.layers.6.out_channels:
  chosen: 192
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 216
input_shape:
  chosen:
  - 192
  - 192
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1984
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 4
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 5
backbone.layers.4.expand_ratio:
  chosen: 4
backbone.layers.4.depth:
  chosen: 3
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 5
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 3
backbone.layers.6.out_channels:
  chosen: 192
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 216
input_shape:
  chosen:
  - 224
  - 224
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1984
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 5
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 3
backbone.layers.4.expand_ratio:
  chosen: 4
backbone.layers.4.depth:
  chosen: 3
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 5
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 4
backbone.layers.6.out_channels:
  chosen: 200
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 224
input_shape:
  chosen:
  - 224
  - 224
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1984
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 2
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 4
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 3
backbone.layers.4.expand_ratio:
  chosen: 4
backbone.layers.4.depth:
  chosen: 4
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 5
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 3
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 5
backbone.layers.6.out_channels:
  chosen: 208
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 224
input_shape:
  chosen:
  - 224
  - 224
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1984
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 4
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 5
backbone.layers.4.expand_ratio:
  chosen: 5
backbone.layers.4.depth:
  chosen: 4
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 5
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 5
backbone.layers.6.out_channels:
  chosen: 192
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 216
input_shape:
  chosen:
  - 256
  - 256
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1792
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 5
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.out_channels:
  chosen: 32
backbone.layers.4.kernel_size:
  chosen: 5
backbone.layers.4.expand_ratio:
  chosen: 4
backbone.layers.4.depth:
  chosen: 3
backbone.layers.4.out_channels:
  chosen: 64
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.depth:
  chosen: 4
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.kernel_size:
  chosen: 3
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.depth:
  chosen: 6
backbone.layers.6.out_channels:
  chosen: 192
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.out_channels:
  chosen: 224
input_shape:
  chosen:
  - 256
  - 256
@@ -0,0 +1,64 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1984
backbone.layers.1.kernel_size:
  chosen: 3
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.expand_ratio:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 24
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.depth:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4
backbone.layers.2.out_channels:
  chosen: 32
backbone.layers.3.kernel_size:
  chosen: 3
backbone.layers.3.depth:
  chosen: 3
backbone.layers.3.expand_ratio:
  chosen: 6
backbone.layers.3.out_channels:
  chosen: 40
backbone.layers.4.kernel_size:
  chosen: 3
backbone.layers.4.depth:
  chosen: 4
backbone.layers.4.expand_ratio:
  chosen: 5
backbone.layers.4.out_channels:
  chosen: 72
backbone.layers.5.kernel_size:
  chosen: 3
backbone.layers.5.depth:
  chosen: 4
backbone.layers.5.expand_ratio:
  chosen: 4
backbone.layers.5.out_channels:
  chosen: 128
backbone.layers.6.kernel_size:
  chosen: 5
backbone.layers.6.depth:
  chosen: 6
backbone.layers.6.expand_ratio:
  chosen: 6
backbone.layers.6.out_channels:
  chosen: 216
backbone.layers.7.kernel_size:
  chosen: 3
backbone.layers.7.depth:
  chosen: 1
backbone.layers.7.expand_ratio:
  chosen: 6
backbone.layers.7.out_channels:
  chosen: 224
input_shape:
  chosen:
  - 288
  - 288
@@ -39,17 +39,20 @@ sh tools/slurm_test.sh $PARTITION $JOB_NAME \

| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 | Config | Download | Remarks |
| :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
-| ImageNet | AttentiveMobileNetV3 | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 8.9(min) / 23.3(max) | 203(min) / 1939(max) | 77.29(min) / 81.65(max) | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | MMRazor searched |
-| ImageNet | AttentiveMobileNetV3 | [AttentiveNAS-A0\*](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 11.559 | 414 | 77.01 | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | Converted from the repo |
-| ImageNet | AttentiveMobileNetV3 | [AttentiveNAS-A6\*](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 16.476 | 1163 | 80.12 | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | Converted from the repo |
+| ImageNet | AttentiveMobileNetV3 | [mutable](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 8.9(min) / 23.3(max) | 203(min) / 1939(max) | 77.25(min) / 81.72(max) | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | MMRazor searched |
+| ImageNet | AttentiveMobileNetV3 | [AttentiveNAS-A0\*](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 11.559 | 414 | 77.252 | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | Converted from the repo |
+| ImageNet | AttentiveMobileNetV3 | [AttentiveNAS-A6\*](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f_mutable_cfg.yaml) | 16.476 | 1163 | 80.790 | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [pretrain](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_shufflenetv2_8xb128_in1k_acc-74.08_20211223-92e9b66a.pth) \| [model](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.pth) \| [log](https://download.openmmlab.com/mmrazor/v0.1/nas/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20211222-67fea61f.log.json) | Converted from the repo |

*Models with \* are converted from the [official repo](https://github.com/facebookresearch/AttentiveNAS). The config files of these models are only for inference. We don't guarantee the training accuracy of these config files and welcome you to contribute your reproduction results.*

-**Note**:
+**Note**: In the official `AttentiveNAS` code, evaluating the Calib-BN subnet with `AutoAugmentation` and a large batch size such as `256` is recommended, as it leads to higher performance. Compared with the original configuration file, this configuration has been modified as follows (a sketch of the resulting overrides is given below):
+
+- modified the `batchsize`-related settings in `train_pipeline` and `test_pipeline`, e.g. setting `train_dataloader.batch_size=256`, `val_dataloader.batch_size=256`, `test_cfg.calibrate_sample_num=16384` and `collate_fn=dict(type='default_collate')` in `train_dataloader`.
+- set `dict(type='mmrazor.AutoAugment', policies='original')` instead of `dict(type='mmrazor.AutoAugmentV2', policies=policies)` in `train_pipeline`.
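
The bullets above translate roughly into the following override config. This is a minimal sketch, assuming the supernet config `attentive_mobilenet_supernet_32xb64_in1k.py` as the base; the merge points follow standard MMEngine config semantics and are illustrative rather than taken verbatim from this PR:

```python
# Hypothetical Calib-BN evaluation overrides; the field values follow the
# bullets above, the surrounding structure is assumed.
_base_ = 'attentive_mobilenet_supernet_32xb64_in1k.py'

# use the original AutoAugment policy set instead of AutoAugmentV2
train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmrazor.AutoAugment', policies='original'),
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=256,  # large batch size recommended by the official code
    collate_fn=dict(type='default_collate'),
    dataset=dict(pipeline=train_pipeline))  # re-attach the new pipeline
val_dataloader = dict(batch_size=256)

# calibrate BN statistics over more samples before evaluation
test_cfg = dict(calibrate_sample_num=16384)
```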

1. Used the search_space in AttentiveNAS, which is different from the BigNAS paper.
-2. The Top-1 Acc is unstable and may fluctuate by about 0.3; convert the official weights with the [converter script](../../../../tools/model_converters/convert_attentivenas_nas_ckpt.py). A Calib-BN model will be released later.
+2. The Top-1 Acc is unstable and may fluctuate by about 0.1; convert the official weights with the [converter script](../../../../tools/model_converters/convert_attentivenas_nas_ckpt.py). A Calib-BN model will be released later.
3. We have observed that the searchable model has been officially released. Although the subnet accuracy has decreased, it is more efficient. We will also provide the supernet training configuration in the future.

## Citation
@@ -1,5 +1,5 @@
_base_ = 'attentive_mobilenet_supernet_32xb64_in1k.py'

-model = dict(fix_subnet='configs/nas/mmcls/bignas/subnet_a6.yaml')
+model = dict(fix_subnet='configs/nas/mmcls/bignas/ATTENTIVE_SUBNET_A6.yaml')

test_cfg = dict(evaluate_fixed_subnet=True)
@@ -20,15 +20,15 @@ supernet = dict(
            mode='original',
            loss_weight=1.0),
        topk=(1, 5)),
-    input_resizer_cfg=_base_._INPUT_MUTABLE,
-    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
+    input_resizer_cfg=_base_.input_resizer_cfg,
+    connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
)

model = dict(
    _scope_='mmrazor',
    type='BigNAS',
    strategy='sandwich4',
    drop_path_rate=0.2,
    num_random_samples=2,
    backbone_dropout_stages=[6, 7],
    architecture=supernet,
    data_preprocessor=_base_.data_preprocessor,
@@ -56,8 +56,6 @@ model = dict(
        parse_cfg={'type': 'Predefined'}),
    value_mutator=dict(type='DynamicValueMutator')))

-optim_wrapper = dict(accumulative_counts=4)
-
model_wrapper_cfg = dict(
    type='mmrazor.BigNASDDP',
    broadcast_buffers=False,
@@ -1,64 +0,0 @@
architecture.backbone.first_mutable:
  chosen: 16
architecture.backbone.last_mutable:
  chosen: 1792
architecture.backbone.layer1.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer1.0.mutable_expand_value:
  chosen: 1
architecture.backbone.layer1.mutable_attrs.depth:
  chosen: 1
architecture.backbone.layer1.mutable_out_channels:
  chosen: 16
architecture.backbone.layer2.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer2.0.mutable_expand_value:
  chosen: 4
architecture.backbone.layer2.mutable_attrs.depth:
  chosen: 3
architecture.backbone.layer2.mutable_out_channels:
  chosen: 24
architecture.backbone.layer3.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer3.0.mutable_expand_value:
  chosen: 4
architecture.backbone.layer3.mutable_attrs.depth:
  chosen: 3
architecture.backbone.layer3.mutable_out_channels:
  chosen: 32
architecture.backbone.layer4.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer4.0.mutable_expand_value:
  chosen: 4
architecture.backbone.layer4.mutable_attrs.depth:
  chosen: 3
architecture.backbone.layer4.mutable_out_channels:
  chosen: 64
architecture.backbone.layer5.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer5.0.mutable_expand_value:
  chosen: 4
architecture.backbone.layer5.mutable_attrs.depth:
  chosen: 3
architecture.backbone.layer5.mutable_out_channels:
  chosen: 112
architecture.backbone.layer6.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer6.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer6.mutable_attrs.depth:
  chosen: 3
architecture.backbone.layer6.mutable_out_channels:
  chosen: 192
architecture.backbone.layer7.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 3
architecture.backbone.layer7.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer7.mutable_attrs.depth:
  chosen: 1
architecture.backbone.layer7.mutable_out_channels:
  chosen: 216
architecture.input_resizer.mutable_attrs.shape:
  chosen:
  - 192
  - 192
@@ -1,64 +0,0 @@
architecture.backbone.first_mutable:
  chosen: 24
architecture.backbone.last_mutable:
  chosen: 1984
architecture.backbone.layer1.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer1.0.mutable_expand_value:
  chosen: 1
architecture.backbone.layer1.mutable_attrs.depth:
  chosen: 2
architecture.backbone.layer1.mutable_out_channels:
  chosen: 24
architecture.backbone.layer2.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer2.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer2.mutable_attrs.depth:
  chosen: 5
architecture.backbone.layer2.mutable_out_channels:
  chosen: 32
architecture.backbone.layer3.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer3.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer3.mutable_attrs.depth:
  chosen: 6
architecture.backbone.layer3.mutable_out_channels:
  chosen: 40
architecture.backbone.layer4.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer4.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer4.mutable_attrs.depth:
  chosen: 6
architecture.backbone.layer4.mutable_out_channels:
  chosen: 72
architecture.backbone.layer5.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer5.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer5.mutable_attrs.depth:
  chosen: 8
architecture.backbone.layer5.mutable_out_channels:
  chosen: 128
architecture.backbone.layer6.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer6.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer6.mutable_attrs.depth:
  chosen: 8
architecture.backbone.layer6.mutable_out_channels:
  chosen: 216
architecture.backbone.layer7.0.depthwise_conv.conv.mutable_attrs.kernel_size:
  chosen: 5
architecture.backbone.layer7.0.mutable_expand_value:
  chosen: 6
architecture.backbone.layer7.mutable_attrs.depth:
  chosen: 2
architecture.backbone.layer7.mutable_out_channels:
  chosen: 224
architecture.input_resizer.mutable_attrs.shape:
  chosen:
  - 288
  - 288
@@ -0,0 +1,116 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1280
backbone.layers.1.0.expand_ratio:
  chosen: 1
backbone.layers.1.0.kernel_size:
  chosen: 3
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.0.expand_ratio:
  chosen: 3
backbone.layers.2.0.kernel_size:
  chosen: 3
backbone.layers.2.1.expand_ratio:
  chosen: 3
backbone.layers.2.1.kernel_size:
  chosen: 3
backbone.layers.2.2.expand_ratio:
  chosen: 3
backbone.layers.2.2.kernel_size:
  chosen: 3
backbone.layers.2.3.expand_ratio:
  chosen: 3
backbone.layers.2.3.kernel_size:
  chosen: 3
backbone.layers.2.depth:
  chosen: 2
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.0.expand_ratio:
  chosen: 3
backbone.layers.3.0.kernel_size:
  chosen: 5
backbone.layers.3.1.expand_ratio:
  chosen: 3
backbone.layers.3.1.kernel_size:
  chosen: 5
backbone.layers.3.2.expand_ratio:
  chosen: 3
backbone.layers.3.2.kernel_size:
  chosen: 3
backbone.layers.3.3.expand_ratio:
  chosen: 3
backbone.layers.3.3.kernel_size:
  chosen: 3
backbone.layers.3.depth:
  chosen: 2
backbone.layers.3.out_channels:
  chosen: 40
backbone.layers.4.0.expand_ratio:
  chosen: 4
backbone.layers.4.0.kernel_size:
  chosen: 7
backbone.layers.4.1.expand_ratio:
  chosen: 3
backbone.layers.4.1.kernel_size:
  chosen: 3
backbone.layers.4.2.expand_ratio:
  chosen: 3
backbone.layers.4.2.kernel_size:
  chosen: 3
backbone.layers.4.3.expand_ratio:
  chosen: 3
backbone.layers.4.3.kernel_size:
  chosen: 3
backbone.layers.4.depth:
  chosen: 2
backbone.layers.4.out_channels:
  chosen: 80
backbone.layers.5.0.expand_ratio:
  chosen: 3
backbone.layers.5.0.kernel_size:
  chosen: 5
backbone.layers.5.1.expand_ratio:
  chosen: 4
backbone.layers.5.1.kernel_size:
  chosen: 3
backbone.layers.5.2.expand_ratio:
  chosen: 3
backbone.layers.5.2.kernel_size:
  chosen: 3
backbone.layers.5.3.expand_ratio:
  chosen: 3
backbone.layers.5.3.kernel_size:
  chosen: 3
backbone.layers.5.depth:
  chosen: 2
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.0.expand_ratio:
  chosen: 6
backbone.layers.6.0.kernel_size:
  chosen: 3
backbone.layers.6.1.expand_ratio:
  chosen: 6
backbone.layers.6.1.kernel_size:
  chosen: 7
backbone.layers.6.2.expand_ratio:
  chosen: 3
backbone.layers.6.2.kernel_size:
  chosen: 3
backbone.layers.6.3.expand_ratio:
  chosen: 6
backbone.layers.6.3.kernel_size:
  chosen: 3
backbone.layers.6.depth:
  chosen: 2
backbone.layers.6.out_channels:
  chosen: 160
input_shape:
  chosen:
  - 140
  - 140
@@ -0,0 +1,116 @@
backbone.first_channels:
  chosen: 16
backbone.last_channels:
  chosen: 1280
backbone.layers.1.0.expand_ratio:
  chosen: 1
backbone.layers.1.0.kernel_size:
  chosen: 3
backbone.layers.1.depth:
  chosen: 1
backbone.layers.1.out_channels:
  chosen: 16
backbone.layers.2.0.expand_ratio:
  chosen: 3
backbone.layers.2.0.kernel_size:
  chosen: 5
backbone.layers.2.1.expand_ratio:
  chosen: 3
backbone.layers.2.1.kernel_size:
  chosen: 3
backbone.layers.2.2.expand_ratio:
  chosen: 3
backbone.layers.2.2.kernel_size:
  chosen: 3
backbone.layers.2.3.expand_ratio:
  chosen: 3
backbone.layers.2.3.kernel_size:
  chosen: 3
backbone.layers.2.depth:
  chosen: 2
backbone.layers.2.out_channels:
  chosen: 24
backbone.layers.3.0.expand_ratio:
  chosen: 4
backbone.layers.3.0.kernel_size:
  chosen: 5
backbone.layers.3.1.expand_ratio:
  chosen: 3
backbone.layers.3.1.kernel_size:
  chosen: 5
backbone.layers.3.2.expand_ratio:
  chosen: 3
backbone.layers.3.2.kernel_size:
  chosen: 3
backbone.layers.3.3.expand_ratio:
  chosen: 3
backbone.layers.3.3.kernel_size:
  chosen: 3
backbone.layers.3.depth:
  chosen: 2
backbone.layers.3.out_channels:
  chosen: 40
backbone.layers.4.0.expand_ratio:
  chosen: 4
backbone.layers.4.0.kernel_size:
  chosen: 3
backbone.layers.4.1.expand_ratio:
  chosen: 4
backbone.layers.4.1.kernel_size:
  chosen: 3
backbone.layers.4.2.expand_ratio:
  chosen: 4
backbone.layers.4.2.kernel_size:
  chosen: 5
backbone.layers.4.3.expand_ratio:
  chosen: 4
backbone.layers.4.3.kernel_size:
  chosen: 3
backbone.layers.4.depth:
  chosen: 3
backbone.layers.4.out_channels:
  chosen: 80
backbone.layers.5.0.expand_ratio:
  chosen: 4
backbone.layers.5.0.kernel_size:
  chosen: 3
backbone.layers.5.1.expand_ratio:
  chosen: 3
backbone.layers.5.1.kernel_size:
  chosen: 5
backbone.layers.5.2.expand_ratio:
  chosen: 4
backbone.layers.5.2.kernel_size:
  chosen: 7
backbone.layers.5.3.expand_ratio:
  chosen: 3
backbone.layers.5.3.kernel_size:
  chosen: 3
backbone.layers.5.depth:
  chosen: 3
backbone.layers.5.out_channels:
  chosen: 112
backbone.layers.6.0.expand_ratio:
  chosen: 6
backbone.layers.6.0.kernel_size:
  chosen: 3
backbone.layers.6.1.expand_ratio:
  chosen: 3
backbone.layers.6.1.kernel_size:
  chosen: 3
backbone.layers.6.2.expand_ratio:
  chosen: 3
backbone.layers.6.2.kernel_size:
  chosen: 5
backbone.layers.6.3.expand_ratio:
  chosen: 3
backbone.layers.6.3.kernel_size:
  chosen: 5
backbone.layers.6.depth:
  chosen: 4
backbone.layers.6.out_channels:
  chosen: 160
input_shape:
  chosen:
  - 152
  - 152
@@ -0,0 +1,49 @@
# Once-For-All

> [ONCE-FOR-ALL: TRAIN ONE NETWORK AND SPECIALIZE IT FOR EFFICIENT DEPLOYMENT](https://arxiv.org/abs/1908.09791)

<!-- [ALGORITHM] -->

## Abstract

We address the challenging problem of efficient inference across many devices and resource constraints, especially on edge devices. Conventional approaches either manually design or use neural architecture search (NAS) to find a specialized neural network and train it from scratch for each case, which is computationally prohibitive (causing as much CO2 emission as 5 cars' lifetimes, Strubell et al. (2019)) and thus unscalable. In this work, we propose to train a once-for-all (OFA) network that supports diverse architectural settings by decoupling training and search, to reduce the cost. We can quickly get a specialized sub-network by selecting from the OFA network without additional training. To efficiently train OFA networks, we also propose a novel progressive shrinking algorithm, a generalized pruning method that reduces the model size across many more dimensions than pruning (depth, width, kernel size, and resolution). It can obtain a surprisingly large number of sub-networks (> 10^19) that can fit different hardware platforms and latency constraints while maintaining the same level of accuracy as training independently. On diverse edge devices, OFA consistently outperforms state-of-the-art (SOTA) NAS methods (up to 4.0% ImageNet top-1 accuracy improvement over MobileNetV3, or same accuracy but 1.5× faster than MobileNetV3, 2.6× faster than EfficientNet w.r.t. measured latency) while reducing GPU hours and CO2 emission by many orders of magnitude. In particular, OFA achieves a new SOTA 80.0% ImageNet top-1 accuracy under the mobile setting (\<600M MACs). OFA is the winning solution for the 3rd Low Power Computer Vision Challenge (LPCVC), DSP classification track, and the 4th LPCVC, both classification track and detection track.

## Introduction

We provide inference models that are published by the official Once-For-All repo and converted by MMRazor.

### Subnet test on ImageNet

```bash
sh tools/slurm_test.sh $PARTITION $JOB_NAME \
  configs/nas/mmcls/onceforall/ofa_mobilenet_subnet_8xb256_in1k.py \
  $STEP2_CKPT --work-dir $WORK_DIR --eval accuracy
```

## Results and models

| Dataset | Supernet | Subnet | Params(M) | Flops(G) | Top-1 | Config | Download | Remarks |
| :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
| ImageNet | AttentiveMobileNetV3 | [search space](configs/_base_/nas_backbones/ofa_mobilenetv3_supernet.py) | 7.6 | 747.8 | 77.5 | [config](./detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/ofa/ofa_mobilenet_supernet_d234_e346_k357_w1_0.py_20221214_0940-d0ebc66f.pth) | Converted from the repo |
| ImageNet | AttentiveMobileNetV3 | note8_lat@22ms_top1@70.4_finetune@25 | 4.3 | 70.9 | 70.3 | [config](./OFA_SUBNET_NOTE8_LAT22.yaml) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4022ms_top1%4070.4_finetune%4025.py_20221214_0938-fb7fb84f.pth) | Converted from the repo |
| ImageNet | AttentiveMobileNetV3 | note8_lat@31ms_top1@72.8_finetune@25 | 4.6 | 105.4 | 72.6 | [config](./OFA_SUBNET_NOTE8_LAT31.yaml) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4031ms_top1%4072.8_finetune%4025.py_20221214_0939-981a8b2a.pth) | Converted from the repo |

**Note**:

1. OFA provides a more fine-grained search mode that searches the expand ratio & kernel size of each block in every layer of the defined supernet; the subnet configs (in .yaml format) are therefore more complex than those of BigNAS/AttentiveNAS (see the fragment below).
2. We provide the [ofa script](../../../../tools/model_converters/convert_ofa_ckpt.py) to convert the official weights into MMRazor style. The layer depth of a specific subnet is required when converting keys.
3. The models above are converted from the [once-for-all official repo](https://github.com/mit-han-lab/once-for-all). The config files of these models are only for inference. We don't guarantee the training accuracy of these config files, and you are welcome to contribute your reproduction results.
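
To make the difference in note 1 concrete, compare a per-layer choice from one of the `ATTENTIVE_SUBNET_*` files with the per-block choices from one of the `OFA_SUBNET_*` files; both fragments below are quoted from the subnet YAMLs added in this PR:

```yaml
# BigNAS/AttentiveNAS: one kernel size / expand ratio per searchable layer
backbone.layers.2.kernel_size:
  chosen: 3
backbone.layers.2.expand_ratio:
  chosen: 4

# OFA fine-grained mode: one choice per block within the layer
backbone.layers.2.0.kernel_size:
  chosen: 5
backbone.layers.2.1.kernel_size:
  chosen: 3
```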

## Citation

```latex
@inproceedings{
  cai2020once,
  title={Once for All: Train One Network and Specialize it for Efficient Deployment},
  author={Han Cai and Chuang Gan and Tianzhe Wang and Zhekai Zhang and Song Han},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://arxiv.org/pdf/1908.09791.pdf}
}
```
@@ -0,0 +1,16 @@
_base_ = ['./ofa_mobilenet_supernet_32xb64_in1k.py']

train_cfg = dict(
    _delete_=True,
    type='mmrazor.EvolutionSearchLoop',
    dataloader=_base_.val_dataloader,
    evaluator=_base_.val_evaluator,
    max_epochs=1,
    num_candidates=2,
    top_k=1,
    num_mutation=1,
    num_crossover=1,
    mutate_prob=0.1,
    calibrate_sample_num=4096,
    constraints_range=dict(flops=(0., 700.)),
    score_key='accuracy/top1')
@@ -0,0 +1,6 @@
_base_ = 'ofa_mobilenet_supernet_32xb64_in1k.py'

model = dict(
    fix_subnet='configs/nas/mmcls/onceforall/OFA_SUBNET_NOTE8_LAT22.yaml')

test_cfg = dict(evaluate_fixed_subnet=True)
@@ -0,0 +1,61 @@
_base_ = [
    'mmcls::_base_/default_runtime.py',
    'mmrazor::_base_/settings/imagenet_bs2048_ofa.py',
    'mmrazor::_base_/nas_backbones/ofa_mobilenetv3_supernet.py',
]

supernet = dict(
    _scope_='mmrazor',
    type='SearchableImageClassifier',
    backbone=_base_.nas_backbone,
    neck=dict(type='mmcls.GlobalAveragePooling'),
    head=dict(
        type='DynamicLinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(
            type='mmcls.LabelSmoothLoss',
            num_classes=1000,
            label_smooth_val=0.1,
            mode='original',
            loss_weight=1.0),
        topk=(1, 5)),
    input_resizer_cfg=_base_.input_resizer_cfg,
    connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
)

model = dict(
    _scope_='mmrazor',
    type='BigNAS',
    drop_path_rate=0.2,
    backbone_dropout_stages=[6, 7],
    architecture=supernet,
    data_preprocessor=_base_.data_preprocessor,
    distiller=dict(
        type='ConfigurableDistiller',
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(type='KLDivergence', tau=1, loss_weight=1)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(recorder='fc', from_student=True),
                preds_T=dict(recorder='fc', from_student=False)))),
    mutators=dict(
        channel_mutator=dict(
            type='mmrazor.OneShotChannelMutator',
            channel_unit_cfg={
                'type': 'OneShotMutableChannelUnit',
                'default_args': {
                    'unit_predefined': True
                }
            },
            parse_cfg={'type': 'Predefined'}),
        value_mutator=dict(type='DynamicValueMutator')))

model_wrapper_cfg = dict(
    type='mmrazor.BigNASDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)
@@ -47,8 +47,8 @@ class AutoSlim(BaseAlgorithm):
                 distiller: VALID_DISTILLER_TYPE,
                 architecture: Union[BaseModel, Dict],
                 data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
-                 init_cfg: Optional[Dict] = None,
-                 num_samples: int = 2) -> None:
+                 num_random_samples: int = 2,
+                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(architecture, data_preprocessor, init_cfg)
        self.mutator = self._build_mutator(mutator)
        # `prepare_from_supernet` must be called before distiller initialized
@@ -107,23 +107,23 @@ class AutoSlim(BaseAlgorithm):
        batch_inputs, data_samples = self.data_preprocessor(data, True)

        total_losses = dict()
-        # update the max subnet loss.
-        self.set_max_subnet()
-        ......
-        total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))
-
-        # update the min subnet loss.
-        self.set_min_subnet()
-        min_subnet_losses = distill_step(batch_inputs, data_samples)
-        total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))
-
-        # update the random subnet loss.
-        for sample_idx in range(self.num_samples):
-            self.set_subnet(self.sample_subnet())
-            random_subnet_losses = distill_step(batch_inputs, data_samples)
-            total_losses.update(
-                add_prefix(random_subnet_losses,
-                           f'random_subnet_{sample_idx}'))
+        for kind in self.sample_kinds:
+            # update the max subnet loss.
+            if kind == 'max':
+                self.set_max_subnet()
+                ......
+                total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))
+            # update the min subnet loss.
+            elif kind == 'min':
+                self.set_min_subnet()
+                min_subnet_losses = distill_step(batch_inputs, data_samples)
+                total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))
+            # update the random subnets loss.
+            elif 'random' in kind:
+                self.set_subnet(self.sample_subnet())
+                random_subnet_losses = distill_step(batch_inputs, data_samples)
+                total_losses.update(
+                    add_prefix(random_subnet_losses, f'{kind}_subnet'))

        return total_losses
```
@@ -17,8 +17,8 @@ class AutoSlim(BaseAlgorithm):
                 distiller,
                 architecture,
                 data_preprocessor,
-                 init_cfg = None,
-                 num_samples = 2) -> None:
+                 num_random_samples = 2,
+                 init_cfg = None) -> None:
        super().__init__(**kwargs)
        pass
@@ -78,8 +78,8 @@ class AutoSlim(BaseAlgorithm):
                 distiller,
                 architecture,
                 data_preprocessor,
-                 init_cfg = None,
-                 num_samples = 2) -> None:
+                 num_random_samples = 2,
+                 init_cfg = None) -> None:
        super(AutoSlim, self).__init__(**kwargs)
        pass
@@ -95,24 +95,27 @@ class AutoSlim(BaseAlgorithm):
        batch_inputs, data_samples = self.data_preprocessor(data, True)

        total_losses = dict()
-        self.set_max_subnet()
-        with optim_wrapper.optim_context(
-                self), self.distiller.teacher_recorders:  # type: ignore
-            max_subnet_losses = self(batch_inputs, data_samples, mode='loss')
-            parsed_max_subnet_losses, _ = self.parse_losses(max_subnet_losses)
-            optim_wrapper.update_params(parsed_max_subnet_losses)
-        total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))
-
-        self.set_min_subnet()
-        min_subnet_losses = distill_step(batch_inputs, data_samples)
-        total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))
-
-        for sample_idx in range(self.num_samples):
-            self.set_subnet(self.sample_subnet())
-            random_subnet_losses = distill_step(batch_inputs, data_samples)
-            total_losses.update(
-                add_prefix(random_subnet_losses,
-                           f'random_subnet_{sample_idx}'))
+        for kind in self.sample_kinds:
+            # update the max subnet loss.
+            if kind == 'max':
+                self.set_max_subnet()
+                with optim_wrapper.optim_context(
+                        self), self.distiller.teacher_recorders:  # type: ignore
+                    max_subnet_losses = self(batch_inputs, data_samples, mode='loss')
+                    parsed_max_subnet_losses, _ = self.parse_losses(max_subnet_losses)
+                    optim_wrapper.update_params(parsed_max_subnet_losses)
+                total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))
+            # update the min subnet loss.
+            elif kind == 'min':
+                self.set_min_subnet()
+                min_subnet_losses = distill_step(batch_inputs, data_samples)
+                total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))
+            # update the random subnets loss.
+            elif 'random' in kind:
+                self.set_subnet(self.sample_subnet())
+                random_subnet_losses = distill_step(batch_inputs, data_samples)
+                total_losses.update(
+                    add_prefix(random_subnet_losses, f'{kind}_subnet'))

        return total_losses
```
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .crd_dataset_wrapper import CRDDataset
-from .transforms import AutoAugmentV2, PackCRDClsInputs
+from .transforms import AutoAugment, AutoAugmentV2, PackCRDClsInputs

-__all__ = ['AutoAugmentV2', 'PackCRDClsInputs', 'CRDDataset']
+__all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs', 'CRDDataset']
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .auto_augment import AutoAugment
from .auto_augmentv2 import AutoAugmentV2
from .formatting import PackCRDClsInputs

-__all__ = ['AutoAugmentV2', 'PackCRDClsInputs']
+__all__ = ['AutoAugment', 'AutoAugmentV2', 'PackCRDClsInputs']
@ -0,0 +1,415 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import random
from typing import Optional

import numpy as np
import PIL
from mmcv.transforms import BaseTransform
from PIL import Image, ImageEnhance, ImageOps

from mmrazor.registry import TRANSFORMS

_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.

_HPARAMS_DEFAULT = dict(
    translate_const=250,
    img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop('resample', Image.NEAREST)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
        kwargs.pop('fillcolor')
    kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    """ShearX images."""

    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0),
                         **kwargs)


def shear_y(img, factor, **kwargs):
    """ShearY images."""

    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0),
                         **kwargs)


def translate_x_rel(img, pct, **kwargs):
    """TranslateXRel images."""

    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
                         **kwargs)


def translate_y_rel(img, pct, **kwargs):
    """TranslateYRel images."""

    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
                         **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    """TranslateX images."""

    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0),
                         **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    """TranslateY images."""

    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels),
                         **kwargs)


def rotate(img, degrees, **kwargs):
    """Rotate images."""

    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(-rotn_center[0] - post_trans[0],
                                         -rotn_center[1] - post_trans[1],
                                         matrix)
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
    """AutoContrast images."""

    return ImageOps.autocontrast(img)


def invert(img, **__):
    """Invert images."""

    return ImageOps.invert(img)


def equalize(img, **__):
    """Equalize images."""

    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    """Solarize images."""

    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    """SolarizeAdd images."""

    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ('L', 'RGB'):
        if img.mode == 'RGB' and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    """Posterize images."""

    if bits_to_keep >= 8:
        return img
    bits_to_keep = max(1, bits_to_keep)  # prevent all 0 images
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    """Contrast images."""

    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    """Color images."""

    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    """Brightness images."""

    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    """Sharpness images."""

    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value."""
    return -v if random.random() > 0.5 else v


class AutoAugmentOp(BaseTransform):

    def __init__(self, name, prob, magnitude, hparams={}):
        NAME_TO_OP = {
            'AutoContrast': auto_contrast,
            'Equalize': equalize,
            'Invert': invert,
            'Rotate': rotate,
            'Posterize': posterize,
            'Posterize2': posterize,
            'Solarize': solarize,
            'SolarizeAdd': solarize_add,
            'Color': color,
            'Contrast': contrast,
            'Brightness': brightness,
            'Sharpness': sharpness,
            'ShearX': shear_x,
            'ShearY': shear_y,
            'TranslateX': translate_x_abs,
            'TranslateY': translate_y_abs,
            'TranslateXRel': translate_x_rel,
            'TranslateYRel': translate_y_rel,
        }
        self.aug_fn = NAME_TO_OP[name]
        self.prob = prob
        self.magnitude = magnitude
        # If std deviation of magnitude is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from normal dist
        # with mean magnitude and std-dev of magnitude_std.
        # NOTE This is being tested as it's not in paper or reference impl.
        self.magnitude_std = 0.5  # FIXME add arg/hparam
        self.kwargs = {
            'fillcolor':
            hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            'resample':
            hparams['interpolation']
            if 'interpolation' in hparams else _RANDOM_INTERPOLATION
        }

        self._get_magnitude(name)

    def _get_magnitude(self, name):
        if name == 'AutoContrast' or name == 'Equalize' or name == 'Invert':
            self.level_fn = self.pass_fn
        elif name == 'Rotate':
            self.level_fn = self._rotate_level_to_arg
        elif name == 'Posterize':
            self.level_fn = self._conversion0
        elif name == 'Posterize2':
            self.level_fn = self._conversion1
        elif name == 'Solarize':
            self.level_fn = self._conversion2
        elif name == 'SolarizeAdd':
            self.level_fn = self._conversion3
        elif name in ['Color', 'Contrast', 'Brightness', 'Sharpness']:
            self.level_fn = self._enhance_level_to_arg
        elif name == 'ShearX' or name == 'ShearY':
            self.level_fn = self._shear_level_to_arg
        elif name == 'TranslateX' or name == 'TranslateY':
            self.level_fn = self._translate_abs_level_to_arg2
        elif name == 'TranslateXRel' or name == 'TranslateYRel':
            self.level_fn = self._translate_rel_level_to_arg
        else:
            print('{} not recognized'.format(name))

        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        self.level_args = self.level_fn(magnitude)

    def _rotate_level_to_arg(self, level):
        # range [-30, 30]
        level = (level / _MAX_LEVEL) * 30.
        level = _randomly_negate(level)
        return (level, )

    def _enhance_level_to_arg(self, level):
        # range [0.1, 1.9]
        return ((level / _MAX_LEVEL) * 1.8 + 0.1, )

    def _shear_level_to_arg(self, level):
        # range [-0.3, 0.3]
        level = (level / _MAX_LEVEL) * 0.3
        level = _randomly_negate(level)
        return (level, )

    def _translate_abs_level_to_arg2(self, level):
        level = (level / _MAX_LEVEL) * float(
            _HPARAMS_DEFAULT['translate_const'])
        level = _randomly_negate(level)
        return (level, )

    def _translate_rel_level_to_arg(self, level):
        # range [-0.45, 0.45]
        level = (level / _MAX_LEVEL) * 0.45
        level = _randomly_negate(level)
        return (level, )

    def pass_fn(self, input):
        return ()

    def _conversion0(self, input):
        return (int((input / _MAX_LEVEL) * 4) + 4, )

    def _conversion1(self, input):
        return (4 - int((input / _MAX_LEVEL) * 4), )

    def _conversion2(self, input):
        return (int((input / _MAX_LEVEL) * 256), )

    def _conversion3(self, input):
        return (int((input / _MAX_LEVEL) * 110), )

    def transform(self, results):
        if self.prob < random.random():
            return results

        for key in results.get('img_fields', ['img']):
            img = Image.fromarray(results[key])
            img = self.aug_fn(img, *self.level_args, **self.kwargs)
            results[key] = np.array(img)
        return results

@TRANSFORMS.register_module()
class AutoAugment(BaseTransform):
    """AutoAugment implementation adapted from timm:
    https://github.com/rwightman/pytorch-image-models.

    The 'original' ImageNet policy comes from the TPU EfficientNet
    implementation; the 'v0' policy follows a PyTorch implementation of
    `AutoAugment: Learning Augmentation Policies from Data
    <https://arxiv.org/abs/1805.09501>`_.
    """
    auto_augment_policy = {
        'original': [
            [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
            [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
            [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
            [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
            [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
            [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
            [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
            [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
            [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
            [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
            [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
            [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
            [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
            [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
            [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
            [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
            [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
            [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
            [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
            [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
            [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
            [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
            [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
            [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
            [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
        ],
        'v0': [
            [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
            [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
            [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
            [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
            [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
            [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
            [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
            [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
            [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
            [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
            [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
            [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
            [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
            [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
            [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
            [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
            [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
            [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
            [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
            [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
            [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
            [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
            [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
            [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
            [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        ]
    }

    def __init__(self, policies: str = 'original', extra_params: dict = None):
        self.policies = copy.deepcopy(self.auto_augment_policy[policies])
        extra_params = extra_params if extra_params else dict(
            translate_const=250, img_mean=_FILL)
        self.sub_policy = [[AutoAugmentOp(*a, extra_params) for a in sp]
                           for sp in self.policies]

    def transform(self, results: dict) -> Optional[dict]:
        sub_policy = random.choice(self.sub_policy)
        for op in sub_policy:
            results = op(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(policies={self.policies})'
        return repr_str
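A minimal usage sketch for the new transform; the `extra_params` values mirror the dataset config earlier in this commit, and the scope-prefixed type string is an assumption about registry setup:

```python
import numpy as np

from mmrazor.registry import TRANSFORMS

aug = TRANSFORMS.build(
    dict(
        type='mmrazor.AutoAugment',
        policies='original',
        extra_params=dict(
            translate_const=int(224 * 0.45), img_mean=(124, 116, 104))))

# one randomly chosen sub-policy (a pair of ops) is applied per call
results = dict(img=np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
results = aug.transform(results)
```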
@ -106,6 +106,7 @@ class AutoAugmentOp(object):

@TRANSFORMS.register_module()
class ShearX(AutoAugmentOp):
    """ShearX images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -129,6 +130,7 @@ class ShearX(AutoAugmentOp):

@TRANSFORMS.register_module()
class ShearY(AutoAugmentOp):
    """ShearY images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -152,6 +154,7 @@ class ShearY(AutoAugmentOp):

@TRANSFORMS.register_module()
class TranslateXRel(AutoAugmentOp):
    """TranslateXRel images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -175,6 +178,7 @@ class TranslateXRel(AutoAugmentOp):

@TRANSFORMS.register_module()
class TranslateYRel(AutoAugmentOp):
    """TranslateYRel images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -198,6 +202,7 @@ class TranslateYRel(AutoAugmentOp):

@TRANSFORMS.register_module()
class TranslateX(AutoAugmentOp):
    """TranslateX images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -221,6 +226,7 @@ class TranslateX(AutoAugmentOp):

@TRANSFORMS.register_module()
class TranslateY(AutoAugmentOp):
    """TranslateY images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -244,6 +250,7 @@ class TranslateY(AutoAugmentOp):

@TRANSFORMS.register_module()
class RotateV2(AutoAugmentOp):
    """Rotate images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -292,6 +299,7 @@ class RotateV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class AutoContrastV2(AutoAugmentOp):
    """AutoContrast images."""

    def __call__(self, results, **__):
        if self.prob < random.random():

@ -307,6 +315,7 @@ class AutoContrastV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class InvertV2(AutoAugmentOp):
    """Invert images."""

    def __call__(self, results, **__):
        if self.prob < random.random():

@ -322,6 +331,7 @@ class InvertV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class EqualizeV2(AutoAugmentOp):
    """Equalize images."""

    def __call__(self, results, **__):
        if self.prob < random.random():

@ -337,6 +347,7 @@ class EqualizeV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class SolarizeV2(AutoAugmentOp):
    """Solarize images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -360,6 +371,7 @@ class SolarizeV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class SolarizeAddV2(AutoAugmentOp):
    """SolarizeAdd images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -393,6 +405,7 @@ class SolarizeAddV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class PosterizeV2(AutoAugmentOp):
    """Posterize images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -419,6 +432,7 @@ class PosterizeV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class ContrastV2(AutoAugmentOp):
    """Contrast images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -440,6 +454,7 @@ class ContrastV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class Color(AutoAugmentOp):
    """Color images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -461,6 +476,7 @@ class Color(AutoAugmentOp):

@TRANSFORMS.register_module()
class BrightnessV2(AutoAugmentOp):
    """Brightness images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -482,6 +498,7 @@ class BrightnessV2(AutoAugmentOp):

@TRANSFORMS.register_module()
class SharpnessV2(AutoAugmentOp):
    """Sharpness images."""

    def __init__(self, prob, magnitude, extra_params: dict):
        super().__init__(prob, magnitude, extra_params)

@ -1,60 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union

from mmengine.evaluator import Evaluator
from mmengine.runner import TestLoop
from torch.utils.data import DataLoader

from mmrazor.models.utils import add_prefix
from mmrazor.registry import LOOPS
from .mixins import CalibrateBNMixin


@LOOPS.register_module()
class AutoSlimTestLoop(TestLoop, CalibrateBNMixin):

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False,
                 calibrate_sample_num: int = 4096) -> None:
        super().__init__(runner, dataloader, evaluator, fp16)

        if self.runner.distributed:
            model = self.runner.model.module
        else:
            model = self.runner.model

        # just for convenience
        self._model = model
        self.calibrate_sample_num = calibrate_sample_num

    def run(self) -> None:
        """Launch validation."""
        self.runner.call_hook('before_test')

        all_metrics = dict()

        self._model.set_max_subnet()
        self.calibrate_bn_statistics(self.runner.train_dataloader,
                                     self.calibrate_sample_num)
        metrics = self._evaluate_once()
        all_metrics.update(add_prefix(metrics, 'max_subnet'))

        # self._model.set_min_subnet()
        # self.calibrate_bn_statistics(self.runner.train_dataloader,
        #                              self.calibrate_sample_num)
        # metrics = self._evaluate_once()
        # all_metrics.update(add_prefix(metrics, 'min_subnet'))

        self.runner.call_hook('after_test_epoch', metrics=all_metrics)
        self.runner.call_hook('after_test')

    def _evaluate_once(self) -> Dict:
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        return self.evaluator.evaluate(len(self.dataloader.dataset))

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random

@ -17,7 +16,7 @@ from mmengine.utils import is_list_of
from torch.utils.data import DataLoader

from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import CalibrateBNMixin, check_subnet_resources, crossover


@ -308,41 +307,35 @@ class EvolutionSearchLoop(EpochBasedTrainLoop, CalibrateBNMixin):
        self.runner.logger.info(f'Resume from epoch: {epoch_start}')
        self.runner.logger.info('#' * 100)

    def export_model(self, model, best_fix_subnet):
        """Export slimmed model according to best_fix_subnet."""
        copied_model = copy.deepcopy(model)
        load_fix_subnet(copied_model, best_fix_subnet)
        if next(copied_model.parameters()).is_cuda:
            copied_model.cuda()
        timestamp_subnet = time.strftime('%Y%m%d_%H%M', time.localtime())
        model_save_name = f'final_subnet_{timestamp_subnet}.pth'
        state_dict = copied_model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            new_state_dict[k] = v
        torch.save({
            'state_dict': new_state_dict,
            'meta': {}
        }, osp.join(self.runner.work_dir, model_save_name))
        return model_save_name

    def _save_best_fix_subnet(self):
        """Save best subnet in searched top-k candidates."""
        if self.runner.rank == 0:
            best_random_subnet = self.top_k_candidates.subnets[0]
            self.model.set_subnet(best_random_subnet)

            best_fix_subnet = export_fix_subnet(self.model)
            best_fix_subnet = self.convert_fix_subnet(best_fix_subnet)
            model_to_save = self.export_model(self.model, best_fix_subnet)
            best_fix_subnet, sliced_model = \
                export_fix_subnet(self.model, slice_weight=True)

            timestamp_subnet = time.strftime('%Y%m%d_%H%M', time.localtime())
            model_name = f'subnet_{timestamp_subnet}.pth'
            save_path = osp.join(self.runner.work_dir, model_name)
            torch.save({
                'state_dict': sliced_model.state_dict(),
                'meta': {}
            }, save_path)
            self.runner.logger.info(f'Subnet checkpoint {model_name} saved in '
                                    f'{self.runner.work_dir}')

            save_name = 'best_fix_subnet.yaml'
            best_fix_subnet = self._convert_fix_subnet(best_fix_subnet)
            fileio.dump(best_fix_subnet,
                        osp.join(self.runner.work_dir, save_name))
            self.runner.logger.info(
                f'Search finished and {save_name} '
                f'{model_to_save} saved in {self.runner.work_dir}.')
                f'Subnet config {save_name} saved in {self.runner.work_dir}.')

    def convert_fix_subnet(self, fix_subnet: Dict[str, Any]):
        self.runner.logger.info('Search finished.')

    def _convert_fix_subnet(self, fix_subnet: Dict[str, Any]):
        """Convert the fixed subnet to avoid python typing error."""
        from mmrazor.utils.typing import DumpChosen

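The tail of `_convert_fix_subnet` is truncated in this hunk. A plausible sketch of what the conversion does, assuming each value is a `DumpChosen` named tuple (hypothetical body; the actual implementation may differ):

```python
def _convert_fix_subnet(fix_subnet):
    """Convert DumpChosen entries to plain dicts so YAML dumping works."""
    from mmrazor.utils.typing import DumpChosen

    converted = dict()
    for name, choice in fix_subnet.items():
        if isinstance(choice, DumpChosen):
            # keep only the serializable `chosen` field (assumed structure)
            converted[name] = dict(chosen=choice.chosen)
        else:
            converted[name] = choice
    return converted
```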

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from mmengine.evaluator import Evaluator
from mmengine.hooks import CheckpointHook

@ -7,7 +7,7 @@ from mmengine.runner import ValLoop
from torch.utils.data import DataLoader

from mmrazor.models.utils import add_prefix
from mmrazor.registry import LOOPS
from mmrazor.registry import LOOPS, TASK_UTILS
from .utils import CalibrateBNMixin


@ -27,15 +27,20 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
        calibrate_sample_num (int): The number of images to compute the true
            average of per-batch mean/variance instead of the running average.
            Defaults to 4096.
        estimator_cfg (dict, Optional): Used for building a resource estimator.
            Defaults to dict(type='mmrazor.ResourceEstimator').
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False,
                 evaluate_fixed_subnet: bool = False,
                 calibrate_sample_num: int = 4096) -> None:
    def __init__(
        self,
        runner,
        dataloader: Union[DataLoader, Dict],
        evaluator: Union[Evaluator, Dict, List],
        fp16: bool = False,
        evaluate_fixed_subnet: bool = False,
        calibrate_sample_num: int = 4096,
        estimator_cfg: Optional[Dict] = dict(type='mmrazor.ResourceEstimator')
    ) -> None:
        super().__init__(runner, dataloader, evaluator, fp16)

        if self.runner.distributed:

@ -43,10 +48,10 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
        else:
            model = self.runner.model

        # just for convenience
        self._model = model
        self.model = model
        self.evaluate_fixed_subnet = evaluate_fixed_subnet
        self.calibrate_sample_num = calibrate_sample_num
        self.estimator = TASK_UTILS.build(estimator_cfg)

        # remove CheckpointHook to avoid extra problems.
        for hook in self.runner._hooks:

@ -64,23 +69,20 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
        if self.evaluate_fixed_subnet:
            metrics = self._evaluate_once()
            all_metrics.update(add_prefix(metrics, 'fix_subnet'))
        else:
            self._model.set_max_subnet()
            metrics = self._evaluate_once()
            all_metrics.update(add_prefix(metrics, 'max_subnet'))

            self._model.set_min_subnet()
            metrics = self._evaluate_once()
            all_metrics.update(add_prefix(metrics, 'min_subnet'))

            sample_nums = self._model.random_samples if hasattr(
                self._model, 'random_samples') else self._model.samples
            for subnet_idx in range(sample_nums):
                self._model.set_subnet(self._model.sample_subnet())
                # compute student metrics
                metrics = self._evaluate_once()
                all_metrics.update(
                    add_prefix(metrics, f'random_subnet_{subnet_idx}'))
        elif hasattr(self.model, 'sample_kinds'):
            for kind in self.model.sample_kinds:
                if kind == 'max':
                    self.model.set_max_subnet()
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, 'max_subnet'))
                elif kind == 'min':
                    self.model.set_min_subnet()
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, 'min_subnet'))
                elif 'random' in kind:
                    self.model.set_subnet(self.model.sample_subnet())
                    metrics = self._evaluate_once()
                    all_metrics.update(add_prefix(metrics, f'{kind}_subnet'))

        self.runner.call_hook('after_val_epoch', metrics=all_metrics)
        self.runner.call_hook('after_val')

@ -93,4 +95,8 @@ class SubnetValLoop(ValLoop, CalibrateBNMixin):
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        return self.evaluator.evaluate(len(self.dataloader.dataset))
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        resource_metrics = self.estimator.estimate(self.model)
        metrics.update(resource_metrics)

        return metrics
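With `estimator_cfg` in place, every evaluation pass now reports resource metrics alongside accuracy. A rough sketch of the extra step (the metric keys named in the comment are illustrative):

```python
from mmrazor.registry import TASK_UTILS

estimator = TASK_UTILS.build(dict(type='mmrazor.ResourceEstimator'))

def evaluate_once(evaluator, dataloader, model):
    metrics = evaluator.evaluate(len(dataloader.dataset))
    # e.g. adds entries such as 'flops' and 'params' to the metric dict
    metrics.update(estimator.estimate(model))
    return metrics
```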
@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Dict, Tuple

import torch

from mmrazor.models import ResourceEstimator
from mmrazor.structures import export_fix_subnet, load_fix_subnet
from mmrazor.structures import export_fix_subnet
from mmrazor.utils import SupportRandomSubnet

try:

@ -32,11 +31,9 @@ def check_subnet_resources(

    assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
    model.set_subnet(subnet)
    fix_mutable = export_fix_subnet(model)
    copied_model = copy.deepcopy(model)
    load_fix_subnet(copied_model, fix_mutable)
    _, sliced_model = export_fix_subnet(model, slice_weight=True)

    model_to_check = model.architecture
    model_to_check = sliced_model.architecture  # type: ignore
    if isinstance(model_to_check, BaseDetector):
        results = estimator.estimate(model=model_to_check.backbone)
    else:

@ -109,6 +109,5 @@ class Autoformer(BaseAlgorithm):
    ) -> LossResults:
        """Calculate losses from a batch of inputs and data samples."""
        if self.is_supernet:
            random_subnet = self.sample_subnet()
            self.set_subnet(random_subnet)
            self.set_subnet(self.sample_subnet())
            return self.architecture(batch_inputs, data_samples, mode='loss')

@ -24,28 +24,28 @@ VALID_CHANNEL_CFG_PATH_TYPE = Union[VALID_PATH_TYPE, List[VALID_PATH_TYPE]]

@MODELS.register_module()
class AutoSlim(BaseAlgorithm):
    """Implementation of Autoslim algorithm. Please refer to
    https://arxiv.org/abs/1903.11728 for more details.

    Args:
        mutator (VALID_MUTATOR_TYPE): config of mutator.
        distiller (VALID_DISTILLER_TYPE): config of distiller.
        architecture (Union[BaseModel, Dict]): the model to be searched.
        data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
            data preprocessor. Defaults to None.
        num_random_samples (int): number of random sample subnets.
            Defaults to 2.
        init_cfg (Optional[Dict], optional): config of initialization.
            Defaults to None.
    """

    def __init__(self,
                 mutator: VALID_MUTATOR_TYPE,
                 distiller: VALID_DISTILLER_TYPE,
                 architecture: Union[BaseModel, Dict],
                 num_random_samples: int = 2,
                 data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
                 init_cfg: Optional[Dict] = None,
                 num_samples: int = 2) -> None:
        """Implementation of Autoslim algorithm. Please refer to
        https://arxiv.org/abs/1903.11728 for more details.

        Args:
            mutator (VALID_MUTATOR_TYPE): config of mutator.
            distiller (VALID_DISTILLER_TYPE): config of distiller.
            architecture (Union[BaseModel, Dict]): the model to be searched.
            data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
                data preprocessor. Defaults to None.
            init_cfg (Optional[Dict], optional): config of initialization.
                Defaults to None.
            num_samples (int, optional): number of sample subnets.
                Defaults to 2.
        """
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(architecture, data_preprocessor, init_cfg)

        self.mutator: OneShotChannelMutator = MODELS.build(mutator)

@ -56,7 +56,9 @@ class AutoSlim(BaseAlgorithm):
        self.distiller.prepare_from_teacher(self.architecture)
        self.distiller.prepare_from_student(self.architecture)

        self.num_samples = num_samples
        self.sample_kinds = ['max', 'min']
        for i in range(num_random_samples):
            self.sample_kinds.append('random' + str(i))

        self._optim_wrapper_count_status_reinitialized = False

@ -125,7 +127,7 @@ class AutoSlim(BaseAlgorithm):
            reinitialize_optim_wrapper_count_status(
                model=self,
                optim_wrapper=optim_wrapper,
                accumulative_counts=self.num_samples + 2)
                accumulative_counts=len(self.sample_kinds))
            self._optim_wrapper_count_status_reinitialized = True

        input_data = self.data_preprocessor(data, True)

@ -133,24 +135,32 @@ class AutoSlim(BaseAlgorithm):
        data_samples = input_data['data_samples']

        total_losses = dict()
        self.set_max_subnet()
        with optim_wrapper.optim_context(
                self), self.distiller.teacher_recorders:  # type: ignore
            max_subnet_losses = self(batch_inputs, data_samples, mode='loss')
            parsed_max_subnet_losses, _ = self.parse_losses(max_subnet_losses)
            optim_wrapper.update_params(parsed_max_subnet_losses)
        total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))

        self.set_min_subnet()
        min_subnet_losses = distill_step(batch_inputs, data_samples)
        total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))

        for sample_idx in range(self.num_samples):
            self.set_subnet(self.sample_subnet())
            random_subnet_losses = distill_step(batch_inputs, data_samples)
            total_losses.update(
                add_prefix(random_subnet_losses,
                           f'random_subnet_{sample_idx}'))
        for kind in self.sample_kinds:
            # update the max subnet loss.
            if kind == 'max':
                self.set_max_subnet()
                with optim_wrapper.optim_context(
                        self
                ), self.distiller.teacher_recorders:  # type: ignore
                    max_subnet_losses = self(
                        batch_inputs, data_samples, mode='loss')
                    parsed_max_subnet_losses, _ = self.parse_losses(
                        max_subnet_losses)
                    optim_wrapper.update_params(parsed_max_subnet_losses)
                total_losses.update(
                    add_prefix(max_subnet_losses, 'max_subnet'))
            # update the min subnet loss.
            elif kind == 'min':
                self.set_min_subnet()
                min_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(
                    add_prefix(min_subnet_losses, 'min_subnet'))
            # update the random subnets loss.
            elif 'random' in kind:
                self.set_subnet(self.sample_subnet())
                random_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(
                    add_prefix(random_subnet_losses, f'{kind}_subnet'))

        return total_losses
@ -194,7 +204,7 @@ class AutoSlimDDP(MMDistributedDataParallel):
            reinitialize_optim_wrapper_count_status(
                model=self,
                optim_wrapper=optim_wrapper,
                accumulative_counts=self.module.num_samples + 2)
                accumulative_counts=len(self.module.sample_kinds))
            self._optim_wrapper_count_status_reinitialized = True

        input_data = self.module.data_preprocessor(data, True)

@ -202,25 +212,32 @@ class AutoSlimDDP(MMDistributedDataParallel):
        data_samples = input_data['data_samples']

        total_losses = dict()
        self.module.set_max_subnet()
        with optim_wrapper.optim_context(
                self), self.module.distiller.teacher_recorders:  # type: ignore
            max_subnet_losses = self(batch_inputs, data_samples, mode='loss')
            parsed_max_subnet_losses, _ = self.module.parse_losses(
                max_subnet_losses)
            optim_wrapper.update_params(parsed_max_subnet_losses)
        total_losses.update(add_prefix(max_subnet_losses, 'max_subnet'))

        self.module.set_min_subnet()
        min_subnet_losses = distill_step(batch_inputs, data_samples)
        total_losses.update(add_prefix(min_subnet_losses, 'min_subnet'))

        for sample_idx in range(self.module.num_samples):
            self.module.set_subnet(self.module.sample_subnet())
            random_subnet_losses = distill_step(batch_inputs, data_samples)
            total_losses.update(
                add_prefix(random_subnet_losses,
                           f'random_subnet_{sample_idx}'))
        for kind in self.module.sample_kinds:
            # update the max subnet loss.
            if kind == 'max':
                self.module.set_max_subnet()
                with optim_wrapper.optim_context(
                        self
                ), self.module.distiller.teacher_recorders:  # type: ignore
                    max_subnet_losses = self(
                        batch_inputs, data_samples, mode='loss')
                    parsed_max_subnet_losses, _ = self.module.parse_losses(
                        max_subnet_losses)
                    optim_wrapper.update_params(parsed_max_subnet_losses)
                total_losses.update(
                    add_prefix(max_subnet_losses, 'max_subnet'))
            # update the min subnet loss.
            elif kind == 'min':
                self.module.set_min_subnet()
                min_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(
                    add_prefix(min_subnet_losses, 'min_subnet'))
            # update the random subnets loss.
            elif 'random' in kind:
                self.module.set_subnet(self.module.sample_subnet())
                random_subnet_losses = distill_step(batch_inputs, data_samples)
                total_losses.update(
                    add_prefix(random_subnet_losses, f'{kind}_subnet'))

        return total_losses

@ -48,7 +48,8 @@ class BigNAS(BaseAlgorithm):
            loaded dict or built :obj:`FixSubnet`. Defaults to None.
        data_preprocessor (Optional[Union[dict, nn.Module]]): The pre-process
            config of :class:`BaseDataPreprocessor`. Defaults to None.
        strategy (str): The sampling strategy. Defaults to `sandwich4`.
        num_random_samples (int): number of random sample subnets.
            Defaults to 2.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.2.
        backbone_dropout_stages (List): Stages to be set dropout. Defaults to
            [6, 7].

@ -62,19 +63,13 @@ class BigNAS(BaseAlgorithm):
        the mutable object ``OneShotMutableChannel`` in BigNAS.
    """

    strategy_groups: Dict[str, List] = {
        'sandwich2': ['max', 'min'],
        'sandwich3': ['max', 'random0', 'min'],
        'sandwich4': ['max', 'random0', 'random1', 'min']
    }

    def __init__(self,
                 architecture: Union[BaseModel, Dict],
                 mutators: VALID_MUTATORS_TYPE,
                 distiller: VALID_DISTILLER_TYPE,
                 fix_subnet: Optional[ValidFixMutable] = None,
                 data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
                 strategy: str = 'sandwich4',
                 num_random_samples: int = 2,
                 drop_path_rate: float = 0.2,
                 backbone_dropout_stages: List = [6, 7],
                 init_cfg: Optional[Dict] = None) -> None:

@ -100,13 +95,14 @@ class BigNAS(BaseAlgorithm):
        self.distiller.prepare_from_teacher(self.architecture)
        self.distiller.prepare_from_student(self.architecture)

        self.strategy = strategy
        self.selects = self.strategy_groups[self.strategy]
        self.random_samples = len([s for s in self.selects if 'random' in s])
        self.is_supernet = True if len(self.selects) > 1 else False
        self.sample_kinds = ['max', 'min']
        for i in range(num_random_samples):
            self.sample_kinds.append('random' + str(i))

        self.drop_path_rate = drop_path_rate
        self.backbone_dropout_stages = backbone_dropout_stages
        self._optim_wrapper_count_status_reinitialized = False
        self.is_supernet = True

        if fix_subnet:
            # Avoid circular import

@ -197,15 +193,16 @@ class BigNAS(BaseAlgorithm):
            reinitialize_optim_wrapper_count_status(
                model=self,
                optim_wrapper=optim_wrapper,
                accumulative_counts=len(self.selects))
                accumulative_counts=len(self.sample_kinds))
            self._optim_wrapper_count_status_reinitialized = True

        batch_inputs, data_samples = self.data_preprocessor(data,
                                                            True).values()

        total_losses = dict()
        for kind in self.selects:
            if kind in ('max'):
        for kind in self.sample_kinds:
            # update the max subnet loss.
            if kind == 'max':
                self.set_max_subnet()
                set_dropout(
                    layers=self.architecture.backbone.layers[:-1],

@ -222,7 +219,8 @@ class BigNAS(BaseAlgorithm):
                    optim_wrapper.update_params(parsed_max_subnet_losses)
                total_losses.update(
                    add_prefix(max_subnet_losses, 'max_subnet'))
            elif kind in ('min'):
            # update the min subnet loss.
            elif kind == 'min':
                self.set_min_subnet()
                set_dropout(
                    layers=self.architecture.backbone.layers[:-1],

@ -233,7 +231,8 @@ class BigNAS(BaseAlgorithm):
                                                 data_samples)
                total_losses.update(
                    add_prefix(min_subnet_losses, 'min_subnet'))
            elif kind in ('random0', 'random1'):
            # update the random subnets loss.
            elif 'random' in kind:
                self.set_subnet(self.sample_subnet())
                set_dropout(
                    layers=self.architecture.backbone.layers[:-1],

@ -288,15 +287,16 @@ class BigNASDDP(MMDistributedDataParallel):
            reinitialize_optim_wrapper_count_status(
                model=self,
                optim_wrapper=optim_wrapper,
                accumulative_counts=len(self.module.selects))
                accumulative_counts=len(self.module.sample_kinds))
            self._optim_wrapper_count_status_reinitialized = True

        batch_inputs, data_samples = self.module.data_preprocessor(
            data, True).values()

        total_losses = dict()
        for kind in self.module.selects:
            if kind in ('max'):
        for kind in self.module.sample_kinds:
            # update the max subnet loss.
            if kind == 'max':
                self.module.set_max_subnet()
                set_dropout(
                    layers=self.module.architecture.backbone.layers[:-1],

@ -313,7 +313,8 @@ class BigNASDDP(MMDistributedDataParallel):
                    optim_wrapper.update_params(parsed_max_subnet_losses)
                total_losses.update(
                    add_prefix(max_subnet_losses, 'max_subnet'))
            elif kind in ('min'):
            # update the min subnet loss.
            elif kind == 'min':
                self.module.set_min_subnet()
                set_dropout(
                    layers=self.module.architecture.backbone.layers[:-1],

@ -324,8 +325,8 @@ class BigNASDDP(MMDistributedDataParallel):
                                               data_samples)
                total_losses.update(
                    add_prefix(min_subnet_losses, 'min_subnet'))

            elif kind in ('random0', 'random1'):
            # update the random subnets loss.
            elif 'random' in kind:
                self.module.set_subnet(self.module.sample_subnet())
                set_dropout(
                    layers=self.module.architecture.backbone.layers[:-1],
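The removed `strategy_groups` table maps one-to-one onto the new `num_random_samples` argument; a quick sanity check of the correspondence (order within a step is not significant here):

```python
def kinds_from_num_random(num_random_samples: int) -> list:
    kinds = ['max', 'min']
    for i in range(num_random_samples):
        kinds.append('random' + str(i))
    return kinds

assert set(kinds_from_num_random(0)) == {'max', 'min'}  # old 'sandwich2'
assert set(kinds_from_num_random(1)) == {'max', 'random0', 'min'}  # 'sandwich3'
assert set(kinds_from_num_random(2)) == {
    'max', 'random0', 'random1', 'min'}  # old 'sandwich4'
```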

@ -92,7 +92,7 @@ class Darts(BaseAlgorithm):

        subnet = self.mutator.sample_choices()
        self.mutator.set_choices(subnet)
        return export_fix_subnet(self)
        return export_fix_subnet(self)[0]

    def train(self, mode=True):
        """Convert the model into eval mode while keep normalization layer

@ -122,7 +122,7 @@ class DSNAS(BaseAlgorithm):

        subnet = self.mutator.sample_choices()
        self.mutator.set_choices(subnet)
        return export_fix_subnet(self)
        return export_fix_subnet(self)[0]

    def fix_subnet(self):
        """Fix subnet when finetuning."""

@ -35,30 +35,33 @@ class AttentiveMobileNetV3(BaseBackbone):
    """Searchable MobileNetV3 backbone.

    Args:
        arch_setting (list[list]): Architecture settings.
        arch_setting (Dict[str, List]): Architecture settings.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
            channels in each layer by this amount. Defaults to 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (7, ).
            Defaults to (7, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        stride_list (list[list]): stride setting in each stage.
            Default: None
        with_se_list (list[list]): Whether to use se-layer in each stage.
            Default: None
            Defaults to -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
            Defaults to None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='Swish').
            Defaults to dict(type='BN').
        act_cfg_list (List): Config dict for activation layer.
            Defaults to None.
        stride_list (list): stride setting in each stage.
            Defaults to None.
        with_se_list (list): Whether to use se-layer in each stage.
            Defaults to None.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Zero norm param in linear conv of MBBlock
            or not when there is a shortcut. Default: True.
            or not when there is a shortcut. Defaults to True.
        fine_grained_mode (bool): Whether to use fine-grained mode (search
            kernel size & expand ratio for each MB block in each layer).
            Defaults to False.
        with_attentive_shortcut (bool): Use shortcut in AttentiveNAS or not.
            Defaults to True.
        init_cfg (dict | list[dict], optional): initialization configuration

@ -72,14 +75,15 @@ class AttentiveMobileNetV3(BaseBackbone):
                 widen_factor: float = 1.,
                 out_indices: Sequence[int] = (7, ),
                 frozen_stages: int = -1,
                 stride_list: List = None,
                 with_se_list: List = None,
                 conv_cfg: Dict = dict(type='BigNasConv2d'),
                 norm_cfg: Dict = dict(type='DynamicBatchNorm2d'),
                 act_cfg: Dict = dict(type='Swish'),
                 act_cfg_list: List = None,
                 stride_list: List = None,
                 with_se_list: List = None,
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = True,
                 fine_grained_mode: bool = False,
                 with_attentive_shortcut: bool = True,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None):

@ -100,15 +104,17 @@ class AttentiveMobileNetV3(BaseBackbone):

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.with_cp = with_cp
        self.fine_grained_mode = fine_grained_mode
        self.with_attentive_shortcut = with_attentive_shortcut

        self.stride_list = stride_list if stride_list is not None \
        self.act_cfg_list = act_cfg_list if act_cfg_list \
            else ['Swish'] * 9
        self.stride_list = stride_list if stride_list \
            else [1, 2, 2, 2, 1, 2, 1]
        self.with_se_list = with_se_list if with_se_list is not None \
        self.with_se_list = with_se_list if with_se_list \
            else [False, False, True, False, True, True, True]

        # adapt mutable settings

@ -123,6 +129,9 @@ class AttentiveMobileNetV3(BaseBackbone):
            make_divisible(c * widen_factor, 8) for c in channels
        ] for channels in self.num_channels_list]

        self.first_act = self.act_cfg_list.pop(0)
        self.last_act = self.act_cfg_list.pop(-1)

        self.first_out_channels_list = self.num_channels_list.pop(0)
        self.last_out_channels_list = self.num_channels_list.pop(-1)
        self.last_expand_ratio_list = self.expand_ratio_list.pop(-1)

@ -146,11 +155,7 @@ class AttentiveMobileNetV3(BaseBackbone):
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='Swish'))

        self.last_mutable = OneShotMutableChannel(
            num_channels=self.in_channels,
            candidate_choices=self.first_out_channels_list)
            act_cfg=dict(type=self.first_act))

        for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
                enumerate(zip(self.num_blocks_list, self.kernel_size_list,

@ -161,7 +166,8 @@ class AttentiveMobileNetV3(BaseBackbone):
                kernel_sizes=kernel_sizes,
                expand_ratios=expand_ratios,
                stride=self.stride_list[i],
                use_se=self.with_se_list[i])
                use_se=self.with_se_list[i],
                act=self.act_cfg_list[i])
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            layers.append(inverted_res_layer)

@ -178,7 +184,7 @@ class AttentiveMobileNetV3(BaseBackbone):
                     padding=0,
                     conv_cfg=self.conv_cfg,
                     norm_cfg=self.norm_cfg,
                     act_cfg=dict(type='Swish'))),
                     act_cfg=dict(type=self.last_act))),
                ('pool', nn.AdaptiveAvgPool2d((1, 1))),
                ('feature_mix_layer',
                 ConvModule(

@ -189,14 +195,14 @@ class AttentiveMobileNetV3(BaseBackbone):
                     bias=False,
                     conv_cfg=self.conv_cfg,
                     norm_cfg=None,
                     act_cfg=dict(type='Swish')))]))
                     act_cfg=dict(type=self.last_act)))]))
        self.add_module('last_conv', last_layers)
        layers.append(last_layers)
        return layers

    def _make_single_layer(self, out_channels: List, num_blocks: List,
                           kernel_sizes: List, expand_ratios: List,
                           stride: int, use_se: bool):
                           stride: int, act: str, use_se: bool):
        """Stack InvertedResidual blocks (MBBlocks) to build a layer for
        MobileNetV3.


@ -229,7 +235,7 @@ class AttentiveMobileNetV3(BaseBackbone):
                expand_ratio=max(expand_ratios),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp,
                se_cfg=se_cfg,
                with_attentive_shortcut=self.with_attentive_shortcut)

@ -245,6 +251,15 @@ class AttentiveMobileNetV3(BaseBackbone):
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        self.first_mutable_channels = OneShotMutableChannel(
            alias='backbone.first_channels',
            num_channels=max(self.first_out_channels_list),
            candidate_choices=self.first_out_channels_list)

        mutate_conv_module(
            self.first_conv, mutable_out_channels=self.first_mutable_channels)

        mid_mutable = self.first_mutable_channels
        # mutate the built mobilenet layers
        for i, layer in enumerate(self.layers[:-1]):
            num_blocks = self.num_blocks_list[i]

@ -252,57 +267,59 @@ class AttentiveMobileNetV3(BaseBackbone):
            expand_ratios = self.expand_ratio_list[i]
            out_channels = self.num_channels_list[i]

            mutable_kernel_size = OneShotMutableValue(
                value_list=kernel_sizes, default_value=max(kernel_sizes))
            mutable_expand_value = OneShotMutableValue(
                value_list=expand_ratios, default_value=max(expand_ratios))
            prefix = 'backbone.layers.' + str(i + 1) + '.'

            mutable_channel_name = 'layer' + str(i +
                                                 1) + '.mutable_out_channels'
            setattr(
                self, mutable_channel_name,
                OneShotMutableChannel(
                    num_channels=max(out_channels),
                    candidate_choices=out_channels))
            mutable_out_channels = OneShotMutableChannel(
                alias=prefix + 'out_channels',
                candidate_choices=out_channels,
                num_channels=max(out_channels))

            se_ratios = [i / 4 for i in expand_ratios]
            mutable_se_channels = OneShotMutableValue(
                value_list=se_ratios, default_value=max(se_ratios))
            if not self.fine_grained_mode:
                mutable_kernel_size = OneShotMutableValue(
                    alias=prefix + 'kernel_size', value_list=kernel_sizes)

            if i == 0:
                mutate_conv_module(
                    self.first_conv, mutable_out_channels=self.last_mutable)

            for k in range(max(self.num_blocks_list[i])):
                mutate_mobilenet_layer(layer[k], self.last_mutable,
                                       getattr(self, mutable_channel_name),
                                       mutable_se_channels,
                                       mutable_expand_value,
                                       mutable_kernel_size)
            self.last_mutable = getattr(self, mutable_channel_name)
                mutable_expand_ratio = OneShotMutableValue(
                    alias=prefix + 'expand_ratio', value_list=expand_ratios)

            mutable_depth = OneShotMutableValue(
                value_list=num_blocks, default_value=max(num_blocks))
                alias=prefix + 'depth', value_list=num_blocks)
            layer.register_mutable_attr('depth', mutable_depth)

        self.last_mutable_out_channels = OneShotMutableChannel(
            for k in range(max(self.num_blocks_list[i])):

                if self.fine_grained_mode:
                    mutable_kernel_size = OneShotMutableValue(
                        alias=prefix + str(k) + '.kernel_size',
                        value_list=kernel_sizes)

                    mutable_expand_ratio = OneShotMutableValue(
                        alias=prefix + str(k) + '.expand_ratio',
                        value_list=expand_ratios)

                mutate_mobilenet_layer(layer[k], mid_mutable,
                                       mutable_out_channels,
                                       mutable_expand_ratio,
                                       mutable_kernel_size)
                mid_mutable = mutable_out_channels

        self.last_mutable_channels = OneShotMutableChannel(
            alias='backbone.last_channels',
            num_channels=self.out_channels,
            candidate_choices=self.last_out_channels_list)

        last_mutable_expand_value = OneShotMutableValue(
            value_list=self.last_expand_ratio_list,
            default_value=max(self.last_expand_ratio_list))

        derived_expand_channels = self.last_mutable * last_mutable_expand_value
        derived_expand_channels = mid_mutable * last_mutable_expand_value
        mutate_conv_module(
            self.layers[-1].final_expand_layer,
            mutable_in_channels=self.last_mutable,
            mutable_in_channels=mid_mutable,
            mutable_out_channels=derived_expand_channels)
        mutate_conv_module(
            self.layers[-1].feature_mix_layer,
            mutable_in_channels=derived_expand_channels,
            mutable_out_channels=self.last_mutable_out_channels)

        self.last_mutable = self.last_mutable_out_channels
            mutable_out_channels=self.last_mutable_channels)

    def forward(self, x):
        x = self.first_conv(x)
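The aliases registered above follow a fixed naming scheme, which is what lets exported subnet YAML files refer to mutables without model introspection. A small sketch of the generated names, assumed from the string concatenations in this hunk:

```python
def layer_aliases(i: int, k: int, fine_grained: bool) -> list:
    prefix = 'backbone.layers.' + str(i + 1) + '.'
    names = [prefix + 'out_channels', prefix + 'depth']
    if fine_grained:
        # per-block kernel size / expand ratio in fine-grained mode
        names += [prefix + f'{k}.kernel_size', prefix + f'{k}.expand_ratio']
    else:
        names += [prefix + 'kernel_size', prefix + 'expand_ratio']
    return names

print(layer_aliases(0, 0, fine_grained=False))
# ['backbone.layers.1.out_channels', 'backbone.layers.1.depth',
#  'backbone.layers.1.kernel_size', 'backbone.layers.1.expand_ratio']
```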
@ -9,7 +9,6 @@ except ImportError:
from torch import Tensor

from mmrazor.models.architectures.dynamic_ops import DynamicInputResizer
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.registry import MODELS


@ -79,18 +78,20 @@ class SearchableImageClassifier(ImageClassifier):
    def _build_input_resizer(self,
                             input_resizer_cfg: Dict) -> DynamicInputResizer:
        """Build an input resizer."""
        input_resizer_cfg_ = input_resizer_cfg['input_resizer']
        input_resizer = MODELS.build(input_resizer_cfg_)
        if not isinstance(input_resizer, DynamicInputResizer):
            raise TypeError('input_resizer should be a `dict` or '
                            '`DynamicInputResizer` instance, but got '
                            f'{type(input_resizer)}')
        mutable_shape_cfg = dict(type='OneShotMutableValue')

        mutable_shape_cfg['alias'] = \
            input_resizer_cfg.get('alias', 'input_shape')

        assert 'input_sizes' in input_resizer_cfg and \
            isinstance(input_resizer_cfg['input_sizes'][0], list), (
                'input_resizer_cfg[`input_sizes`] should be List[list].')
        mutable_shape_cfg['value_list'] = \
            input_resizer_cfg.get('input_sizes')  # type: ignore

        mutable_shape_cfg = input_resizer_cfg['mutable_shape']
        mutable_shape = MODELS.build(mutable_shape_cfg)
        if not isinstance(mutable_shape, OneShotMutableValue):
            raise ValueError('`mutable_shape` should be instance of '
                             'OneShotMutableValue')

        input_resizer = MODELS.build(dict(type='DynamicInputResizer'))
        input_resizer.register_mutable_attr('shape', mutable_shape)

        return input_resizer
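The flat `input_resizer_cfg` format consumed here is the one shown in the updated configs at the top of this commit; internally it is expanded into a `OneShotMutableValue` config roughly as follows (a sketch of the mapping, not the full method):

```python
# New flat config format (from the updated OFA / BigNAS configs).
input_resizer_cfg = dict(
    input_sizes=[[192, 192], [224, 224], [256, 256], [288, 288]])

# `_build_input_resizer` expands it into a mutable-shape config:
mutable_shape_cfg = dict(
    type='OneShotMutableValue',
    alias=input_resizer_cfg.get('alias', 'input_shape'),
    value_list=input_resizer_cfg['input_sizes'])
```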
@ -46,13 +46,21 @@ class ShortcutLayer(BaseOP):
            x = F.avg_pool2d(x, self.reduction, padding=padding)

        # HACK
        mutable_in_channels = self.conv.mutable_in_channels
        mutable_out_channels = self.conv.mutable_out_channels
        if mutable_out_channels is not None and \
                mutable_in_channels is not None:
            if mutable_out_channels.current_mask.sum().item() != \
                    mutable_in_channels.current_mask.sum().item():
                x = self.conv(x)
        if hasattr(self.conv, 'mutable_in_channels'
                   ) and self.conv.mutable_in_channels is not None:
            in_channels = self.conv.mutable_in_channels.current_mask.sum(
            ).item()
        else:
            in_channels = self.conv.in_channels
        if hasattr(self.conv, 'mutable_out_channels'
                   ) and self.conv.mutable_out_channels is not None:
            out_channels = self.conv.mutable_out_channels.current_mask.sum(
            ).item()
        else:
            out_channels = self.conv.out_channels

        if in_channels != out_channels:
            x = self.conv(x)

        return x

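The rewritten check reads the active width from a mutable's current mask when one is registered and falls back to the static conv attribute otherwise; a condensed sketch of the same logic:

```python
def active_channels(conv, attr: str) -> int:
    mutable = getattr(conv, 'mutable_' + attr, None)
    if mutable is not None:
        # number of currently enabled channels in the dynamic conv
        return int(mutable.current_mask.sum().item())
    return getattr(conv, attr)  # plain attribute on a static conv

# the 1x1 shortcut conv is applied only when widths differ:
# if active_channels(self.conv, 'in_channels') != \
#         active_channels(self.conv, 'out_channels'):
#     x = self.conv(x)
```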
@ -31,12 +31,11 @@ def mutate_conv_module(


def mutate_mobilenet_layer(mb_layer: MBBlock, mutable_in_channels,
                           mutable_out_channels, mutable_se_channels,
                           mutable_expand_value, mutable_kernel_size):
                           mutable_out_channels, mutable_expand_ratio,
                           mutable_kernel_size):
    """Mutate MobileNet layers."""
    mb_layer.mutable_expand_value = mutable_expand_value
    mb_layer.derived_expand_channels = \
        mb_layer.mutable_expand_value * mutable_in_channels
        mutable_expand_ratio * mutable_in_channels

    if mb_layer.with_expand_conv:
        mutate_conv_module(
@ -213,11 +213,10 @@ class GroupMixin():
|
|||
def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]],
|
||||
name2mutable: Dict[str, BaseMutable],
|
||||
custom_group: List[List[str]]) -> None:
|
||||
|
||||
"""Check if all keys are legal."""
|
||||
aliases = [*alias2mutable_names.keys()]
|
||||
module_names = [*name2mutable.keys()]
|
||||
|
||||
# check if all keys are legal
|
||||
expanded_custom_group: List[str] = [
|
||||
_ for group in custom_group for _ in group
|
||||
]
|
||||
|
@@ -261,8 +260,10 @@ class MutatorProtocol(Protocol):  # pragma: no cover


class OneShotSampleMixin:
    """Sample mixin for one-shot mutators."""

    def sample_choices(self: MutatorProtocol) -> Dict:
        """Sample choices for each group in search_groups."""
        random_choices = dict()
        for group_id, modules in self.search_groups.items():
            random_choices[group_id] = modules[0].sample_choice()

@@ -270,6 +271,7 @@ class OneShotSampleMixin:
        return random_choices

    def set_choices(self: MutatorProtocol, choices: Dict) -> None:
        """Set choices for each group in search_groups."""
        for group_id, modules in self.search_groups.items():
            choice = choices[group_id]
            for module in modules:

@@ -279,6 +281,7 @@ class OneShotSampleMixin:


class DynamicSampleMixin(OneShotSampleMixin):

    def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict:
        """Sample choices for each group in search_groups."""
        random_choices = dict()
        for group_id, modules in self.search_groups.items():
            if kind == 'max':

@@ -291,6 +294,7 @@ class DynamicSampleMixin(OneShotSampleMixin):

    @property
    def max_choice(self: MutatorProtocol) -> Dict:
        """Get max choices for each group in search_groups."""
        max_choice = dict()
        for group_id, modules in self.search_groups.items():
            max_choice[group_id] = modules[0].max_choice

@@ -299,6 +303,7 @@ class DynamicSampleMixin(OneShotSampleMixin):

    @property
    def min_choice(self: MutatorProtocol) -> Dict:
        """Get min choices for each group in search_groups."""
        min_choice = dict()
        for group_id, modules in self.search_groups.items():
            min_choice[group_id] = modules[0].min_choice
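
The `kind` argument is what unifies sandwich-rule sampling: one entry point serves the max, min, and random subnets of each step. A sketch of the calling pattern from a training loop (helper name and loop body illustrative):

# Sandwich rule: always train the biggest and smallest subnets plus a
# few random ones, all through the same mutator interface.
def sandwich_choices(mutator, num_random=2):
    choices = [mutator.sample_choices(kind='max'),
               mutator.sample_choices(kind='min')]
    choices += [mutator.sample_choices() for _ in range(num_random)]
    return choices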

@@ -7,6 +7,7 @@ from .value_mutator import ValueMutator

@MODELS.register_module()
class DynamicValueMutator(ValueMutator, DynamicSampleMixin):
    """Dynamic value mutator with type as `OneShotMutableValue`."""

    @property
    def mutable_class_type(self):

@@ -73,7 +73,7 @@ class MetricPredictor:
            'cause the model of handler in predictor needs to be initialized.')

        if self.initialize:
            model = export_fix_subnet(model)
            model, _ = export_fix_subnet(model)
            data = self.preprocess(np.array([self.model2vector(model)]))
            score = float(np.squeeze(self.handler.predict(data)))
            if metric.get(self.score_key_list[0], None):
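
Since `export_fix_subnet` now returns a pair, every call site unpacks it; the predictor only needs the subnet description, so the second slot is discarded:

# New convention at call sites that only need the config:
fix_subnet, _ = export_fix_subnet(model)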

@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import copy
from typing import Dict, Optional, Tuple

from mmengine import fileio
from mmengine.logging import print_log
from torch import nn

from mmrazor.utils import FixMutable, ValidFixMutable

@@ -14,16 +14,17 @@ def _dynamic_to_static(model: nn.Module) -> None:
    from mmrazor.models.architectures.dynamic_ops import DynamicMixin

    def traverse_children(module: nn.Module) -> None:
        for name, child in module.named_children():
            if isinstance(child, DynamicMixin):
                setattr(module, name, child.to_static_op())
            else:
                traverse_children(child)
        for name, mutable in module.items():
            if isinstance(mutable, DynamicMixin):
                module[name] = mutable.to_static_op()
            if hasattr(mutable, '_modules'):
                traverse_children(mutable._modules)

    if isinstance(model, DynamicMixin):
        raise RuntimeError('Root model can not be dynamic op.')

    traverse_children(model)
    if hasattr(model, '_modules'):
        traverse_children(model._modules)


def load_fix_subnet(model: nn.Module,

@@ -48,68 +49,90 @@ load_fix_subnet
    from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
    from mmrazor.models.mutables.base_mutable import BaseMutable

    def load_fix_module(module):
        """Load fix module."""
        if getattr(module, 'alias', None):
            alias = module.alias
            assert alias in fix_mutable, \
                f'The alias {alias} is not in fix_modules, ' \
                'please check your `fix_mutable`.'
            # {chosen=xx, meta=xx}
            chosen = fix_mutable.get(alias, None)
        else:
            if prefix:
                mutable_name = name.lstrip(prefix)
            elif extra_prefix:
                mutable_name = extra_prefix + name
            else:
                mutable_name = name
            if mutable_name not in fix_mutable and not isinstance(
                    module, MutableChannelContainer):
                raise RuntimeError(
                    f'The module name {mutable_name} is not in '
                    'fix_mutable, please check your `fix_mutable`.')
            # {chosen=xx, meta=xx}
            chosen = fix_mutable.get(mutable_name, None)

        if not isinstance(chosen, DumpChosen):
            chosen = DumpChosen(**chosen)
        if not module.is_fixed:
            module.fix_chosen(chosen.chosen)

    for name, module in model.named_modules():
        # The format of `chosen` is different for each type of mutable.
        # In the corresponding mutable, it will check whether the `chosen`
        # format is correct.
        if isinstance(module, (MutableChannelContainer, DerivedMutable)):
        if isinstance(module, (MutableChannelContainer)):
            continue
        if isinstance(module, BaseMutable):
            if not module.is_fixed:
                if getattr(module, 'alias', None):
                    alias = module.alias
                    assert alias in fix_mutable, \
                        f'The alias {alias} is not in fix_modules, ' \
                        'please check your `fix_mutable`.'
                    # {chosen=xx, meta=xx}
                    chosen = fix_mutable.get(alias, None)
                else:
                    if prefix:
                        mutable_name = name.lstrip(prefix)
                    elif extra_prefix:
                        mutable_name = extra_prefix + name
                    else:
                        mutable_name = name
                    if mutable_name not in fix_mutable and not isinstance(
                            module, (DerivedMutable, MutableChannelContainer)):
                        raise RuntimeError(
                            f'The module name {mutable_name} is not in '
                            'fix_mutable, please check your `fix_mutable`.')
                    # {chosen=xx, meta=xx}
                    chosen = fix_mutable.get(mutable_name, None)

                if not isinstance(chosen, DumpChosen):
                    chosen = DumpChosen(**chosen)
                module.fix_chosen(chosen.chosen)
        if isinstance(module, BaseMutable):
            if isinstance(module, DerivedMutable):
                for source_mutable in module.source_mutables:
                    load_fix_module(source_mutable)
            else:
                load_fix_module(module)

    # convert dynamic op to static op
    _dynamic_to_static(model)
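
`fix_mutable` here is the flat `name-or-alias -> {chosen: ...}` mapping (or a YAML file of the same shape) that the test fixtures later in this diff exercise. A minimal round trip under that assumption:

# Fix a supernet to one subnet; afterwards every dynamic op in
# `model` has been converted to a plain static module in place.
fix_mutable = {
    'mutable1': {'chosen': 'conv1'},
    'mutable2': {'chosen': 'conv2'},
}
load_fix_subnet(model, fix_mutable)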


def export_fix_subnet(model: nn.Module,
                      dump_derived_mutable: bool = False) -> FixMutable:
    """Export subnet that can be loaded by :func:`load_fix_subnet`."""
    if dump_derived_mutable:
        print_log(
            'Trying to dump information of all derived mutables, '
            'this might harm readability of the exported configurations.',
            level=logging.WARNING)
def export_fix_subnet(
        model: nn.Module,
        slice_weight: bool = False) -> Tuple[FixMutable, Optional[Dict]]:
    """Export subnet config with (optional) the sliced weight.

    Args:
        slice_weight (bool): Whether to return the sliced subnet.
            Defaults to False.
    """
    # Avoid circular import
    from mmrazor.models.mutables import DerivedMutable, MutableChannelContainer
    from mmrazor.models.mutables.base_mutable import BaseMutable

    fix_subnet = dict()
    def module_dump_chosen(module, fix_subnet):
        if module.alias:
            fix_subnet[module.alias] = module.dump_chosen()
        else:
            fix_subnet[name] = module.dump_chosen()

    fix_subnet: Dict[str, DumpChosen] = dict()
    for name, module in model.named_modules():
        if isinstance(module, BaseMutable):
            if isinstance(module,
                          (MutableChannelContainer,
                           DerivedMutable)) and not dump_derived_mutable:
            if isinstance(module, MutableChannelContainer):
                continue

            if module.alias:
                fix_subnet[module.alias] = module.dump_chosen()
            elif isinstance(module, DerivedMutable):
                for source_mutable in module.source_mutables:
                    module_dump_chosen(source_mutable, fix_subnet)
            else:
                fix_subnet[name] = module.dump_chosen()
                module_dump_chosen(module, fix_subnet)

    return fix_subnet
    if slice_weight:
        copied_model = copy.deepcopy(model)
        load_fix_subnet(copied_model, fix_subnet)

        if next(copied_model.parameters()).is_cuda:
            copied_model.cuda()

        return fix_subnet, copied_model

    return fix_subnet, None
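
The rewritten signature folds the old `dump_derived_mutable` flag into the `source_mutables` walk and adds optional weight slicing. The two calling conventions, as the updated tests use them:

# Config only: the second element is None.
fix_subnet, _ = export_fix_subnet(supernet)

# Config plus a deep-copied model whose dynamic ops are sliced down
# to the chosen subnet, ready for estimation or deployment.
fix_subnet, sliced_model = export_fix_subnet(supernet, slice_weight=True)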

@@ -878,6 +878,7 @@ class DynamicMMBlock(nn.Module):
            self.with_attentive_shortcut = True
        self.in_channels = 24

        self.first_out_channels_list = [16]
        self.first_conv = ConvModule(
            in_channels=3,
            out_channels=24,

@@ -888,8 +889,6 @@ class DynamicMMBlock(nn.Module):
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='Swish'))

        self.last_mutable = OneShotMutableChannel(
            num_channels=24, candidate_choices=[16, 24])

        self.layers = []
        for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
                enumerate(zip(self.num_blocks_list, self.kernel_size_list,

@@ -971,13 +970,19 @@
        return dynamic_seq

    def register_mutables(self):
        """Mutate the BigNAS-style MobileNetV3."""
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        # mutate the first conv
        mutate_conv_module(
            self.first_conv, mutable_out_channels=self.last_mutable)
        self.first_mutable_channels = OneShotMutableChannel(
            alias='backbone.first_channels',
            num_channels=max(self.first_out_channels_list),
            candidate_choices=self.first_out_channels_list)

        mutate_conv_module(
            self.first_conv, mutable_out_channels=self.first_mutable_channels)

        mid_mutable = self.first_mutable_channels
        # mutate the built mobilenet layers
        for i, layer in enumerate(self.layers[:-1]):
            num_blocks = self.num_blocks_list[i]

@@ -985,46 +990,48 @@
            expand_ratios = self.expand_ratio_list[i]
            out_channels = self.num_channels_list[i]

            mutable_kernel_size = OneShotMutableValue(
                value_list=kernel_sizes, default_value=max(kernel_sizes))
            mutable_expand_value = OneShotMutableValue(
                value_list=expand_ratios, default_value=max(expand_ratios))
            prefix = 'backbone.layers.' + str(i + 1) + '.'

            mutable_out_channels = OneShotMutableChannel(
                num_channels=max(out_channels), candidate_choices=out_channels)
                alias=prefix + 'out_channels',
                candidate_choices=out_channels,
                num_channels=max(out_channels))

            se_ratios = [i / 4 for i in expand_ratios]
            mutable_se_channels = OneShotMutableValue(
                value_list=se_ratios, default_value=max(se_ratios))
            mutable_kernel_size = OneShotMutableValue(
                alias=prefix + 'kernel_size', value_list=kernel_sizes)

            for k in range(max(self.num_blocks_list[i])):
                mutate_mobilenet_layer(layer[k], self.last_mutable,
                                       mutable_out_channels,
                                       mutable_se_channels,
                                       mutable_expand_value,
                                       mutable_kernel_size)
            self.last_mutable = mutable_out_channels
            mutable_expand_ratio = OneShotMutableValue(
                alias=prefix + 'expand_ratio', value_list=expand_ratios)

            mutable_depth = OneShotMutableValue(
                value_list=num_blocks, default_value=max(num_blocks))
                alias=prefix + 'depth', value_list=num_blocks)
            layer.register_mutable_attr('depth', mutable_depth)

            mutable_out_channels = OneShotMutableChannel(
            for k in range(max(self.num_blocks_list[i])):
                mutate_mobilenet_layer(layer[k], mid_mutable,
                                       mutable_out_channels,
                                       mutable_expand_ratio,
                                       mutable_kernel_size)
            mid_mutable = mutable_out_channels

        self.last_mutable_channels = OneShotMutableChannel(
            alias='backbone.last_channels',
            num_channels=self.out_channels,
            candidate_choices=self.last_out_channels_list)

        last_mutable_expand_value = OneShotMutableValue(
            value_list=self.last_expand_ratio_list,
            default_value=max(self.last_expand_ratio_list))
        derived_expand_channels = self.last_mutable * last_mutable_expand_value

        derived_expand_channels = mid_mutable * last_mutable_expand_value
        mutate_conv_module(
            self.layers[-1].final_expand_layer,
            mutable_in_channels=self.last_mutable,
            mutable_in_channels=mid_mutable,
            mutable_out_channels=derived_expand_channels)
        mutate_conv_module(
            self.layers[-1].feature_mix_layer,
            mutable_in_channels=derived_expand_channels,
            mutable_out_channels=mutable_out_channels)

        self.last_mutable = mutable_out_channels
            mutable_out_channels=self.last_mutable_channels)

    def forward(self, x):
        x = self.first_conv(x)
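
With every mutable registered under an alias, a dumped subnet reads as a flat dict keyed by semantic names ('backbone.layers.<i>.<attr>') instead of parameter paths, which is what the optimized ATTENTIVE_SUBNET YAMLs rely on. Roughly the shape of one dumped stage (choices illustrative):

# Keys follow the aliases registered in register_mutables above.
fix_subnet = {
    'backbone.first_channels': {'chosen': 16},
    'backbone.layers.1.kernel_size': {'chosen': 3},
    'backbone.layers.1.expand_ratio': {'chosen': 4},
    'backbone.layers.1.out_channels': {'chosen': 16},
    'backbone.layers.1.depth': {'chosen': 1},
    'backbone.last_channels': {'chosen': 1792},
}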

@@ -2,3 +2,5 @@ mutable1:
  chosen: conv1
mutable2:
  chosen: conv2
mutable3.0.kernel_size:
  chosen: 3
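
The fixture now records a value mutable, keyed by its dotted alias `mutable3.0.kernel_size`, alongside the op choices, matching what the reworked `export_fix_subnet` dumps for the extended MockModel later in this diff.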

@@ -117,10 +117,10 @@ class TestAutoSlim(TestCase):
        assert losses['max_subnet.loss'] > 0
        assert losses['min_subnet.loss'] > 0
        assert losses['min_subnet.loss_kl'] + 1e-5 > 0
        assert losses['random_subnet_0.loss'] > 0
        assert losses['random_subnet_0.loss_kl'] + 1e-5 > 0
        assert losses['random_subnet_1.loss'] > 0
        assert losses['random_subnet_1.loss_kl'] + 1e-5 > 0
        assert losses['random0_subnet.loss'] > 0
        assert losses['random0_subnet.loss_kl'] + 1e-5 > 0
        assert losses['random1_subnet.loss'] > 0
        assert losses['random1_subnet.loss_kl'] + 1e-5 > 0

        assert algo._optim_wrapper_count_status_reinitialized
        assert optim_wrapper._inner_count == 4

@@ -142,13 +142,13 @@ class TestAutoSlim(TestCase):
            mutator_cfg: MUTATOR_TYPE = MUTATOR_CFG,
            distiller_cfg: DISTILLER_TYPE = DISTILLER_CFG,
            architecture_cfg: Dict = ARCHITECTURE_CFG,
            num_samples: int = 2) -> AutoSlim:
            num_random_samples: int = 2) -> AutoSlim:
        model = AutoSlim(
            mutator=mutator_cfg,
            distiller=distiller_cfg,
            architecture=architecture_cfg,
            data_preprocessor=ToyDataPreprocessor(),
            num_samples=num_samples)
            num_random_samples=num_random_samples)
        model.to(self.device)

        return model

@@ -173,12 +173,12 @@ class TestAutoSlimDDP(TestAutoSlim):
            mutator_cfg: MUTATOR_TYPE = MUTATOR_CFG,
            distiller_cfg: DISTILLER_TYPE = DISTILLER_CFG,
            architecture_cfg: Dict = ARCHITECTURE_CFG,
            num_samples: int = 2) -> AutoSlim:
            num_random_samples: int = 2) -> AutoSlim:
        model = super().prepare_model(
            mutator_cfg=mutator_cfg,
            distiller_cfg=distiller_cfg,
            architecture_cfg=architecture_cfg,
            num_samples=num_samples)
            num_random_samples=num_random_samples)

        return AutoSlimDDP(module=model, find_unused_parameters=True)

@@ -81,7 +81,7 @@ ARCHITECTURE_CFG = dict(
        label_smooth_val=0.1,
        loss_weight=1.0),
    topk=(1, 5)),
    connect_head=dict(connect_with_backbone='backbone.last_mutable'),
    connect_head=dict(connect_with_backbone='backbone.last_mutable_channels'),
)

ALGORITHM_CFG = dict(

@@ -72,7 +72,7 @@ def test_attentive_mobilenet_mutable() -> None:
    elif isinstance(module, DynamicSequential):
        assert isinstance(module.mutable_depth, OneShotMutableValue)

    assert backbone.last_mutable.num_channels == max(out_channels[-1])
    assert backbone.last_mutable_channels.num_channels == max(out_channels[-1])


def test_attentive_mobilenet_train() -> None:

@@ -59,7 +59,7 @@ class TestDynamicConv2d(TestCase):
        with pytest.raises(RuntimeError):
            _ = d_conv2d.to_static_op()

        fix_mutables = export_fix_subnet(d_conv2d)
        fix_mutables = export_fix_subnet(d_conv2d)[0]
        with pytest.raises(RuntimeError):
            load_fix_subnet(d_conv2d, fix_mutables)
        fix_dynamic_op(d_conv2d, fix_mutables)

@@ -126,7 +126,7 @@ def test_dynamic_conv2d(bias: bool, dynamic_class: Type[nn.Conv2d]) -> None:
    out1 = d_conv2d(x)
    assert out1.size(1) == 4

    fix_mutables = export_fix_subnet(d_conv2d)
    fix_mutables = export_fix_subnet(d_conv2d)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_conv2d, fix_mutables)
    fix_dynamic_op(d_conv2d, fix_mutables)

@@ -180,7 +180,7 @@ def test_dynamic_conv2d_mutable_single_channels(
    with pytest.raises(RuntimeError):
        _ = d_conv2d.to_static_op()

    fix_mutables = export_fix_subnet(d_conv2d)
    fix_mutables = export_fix_subnet(d_conv2d)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_conv2d, fix_mutables)
    fix_dynamic_op(d_conv2d, fix_mutables)

@@ -239,7 +239,7 @@ def test_kernel_dynamic_conv2d(dynamic_class: Type[nn.Conv2d],
    out1 = d_conv2d(x)
    assert out1.size(1) == 8

    fix_mutables = export_fix_subnet(d_conv2d)
    fix_mutables = export_fix_subnet(d_conv2d)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_conv2d, fix_mutables)
    fix_dynamic_op(d_conv2d, fix_mutables)

@@ -49,7 +49,7 @@ def test_dynamic_linear(bias) -> None:
    with pytest.raises(RuntimeError):
        _ = d_linear.to_static_op()

    fix_mutables = export_fix_subnet(d_linear)
    fix_mutables = export_fix_subnet(d_linear)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_linear, fix_mutables)
    fix_dynamic_op(d_linear, fix_mutables)

@@ -100,7 +100,7 @@ def test_dynamic_linear_mutable_single_features(
    with pytest.raises(RuntimeError):
        _ = d_linear.to_static_op()

    fix_mutables = export_fix_subnet(d_linear)
    fix_mutables = export_fix_subnet(d_linear)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_linear, fix_mutables)
    fix_dynamic_op(d_linear, fix_mutables)

@@ -59,7 +59,7 @@ def test_dynamic_bn(dynamic_class: Type[nn.modules.batchnorm._BatchNorm],
    out1 = d_bn(x)
    assert out1.size(1) == 8

    fix_mutables = export_fix_subnet(d_bn)
    fix_mutables = export_fix_subnet(d_bn)[0]
    with pytest.raises(RuntimeError):
        load_fix_subnet(d_bn, fix_mutables)
    fix_dynamic_op(d_bn, fix_mutables)

@@ -56,8 +56,10 @@ class TestValueMutator(unittest.TestCase):
        for each_mutables in module.source_mutables:
            if isinstance(each_mutables, MutableValue):
                mutable_value_space.append(each_mutables)
        assert len(
            value_mutator.search_groups) == len(mutable_value_space)
        count = 0
        for values in value_mutator.search_groups.values():
            count += len(values)
        assert count == len(mutable_value_space)

        x = torch.rand([2, 3, 224, 224])
        y = model(x)

@@ -5,6 +5,7 @@ import pytest
import torch.nn as nn

from mmrazor.models import *  # noqa:F403,F401
from mmrazor.models.architectures.dynamic_ops import BigNasConv2d
from mmrazor.models.mutables import OneShotMutableOP, OneShotMutableValue
from mmrazor.registry import MODELS
from mmrazor.structures import export_fix_subnet, load_fix_subnet

@@ -30,10 +31,17 @@ class MockModel(nn.Module):

        self.mutable1 = OneShotMutableOP(convs1)
        self.mutable2 = OneShotMutableOP(convs2)
        self.mutable3 = nn.Sequential(BigNasConv2d(16, 16, 5))

        mutable_kernel_size = OneShotMutableValue(
            alias='mutable3.0.kernel_size', value_list=[3, 5])
        self.mutable3[0].register_mutable_attr('kernel_size',
                                               mutable_kernel_size)

    def forward(self, x):
        x = self.mutable1(x)
        x = self.mutable2(x)
        x = self.mutable3(x)
        return x

@@ -63,6 +71,9 @@ class TestFixSubnet(TestCase):
            'mutable2': {
                'chosen': 'conv2'
            },
            'mutable3.0.kernel_size': {
                'chosen': 3
            }
        }

        model = MockModel()

@@ -90,53 +101,50 @@ class TestFixSubnet(TestCase):
            'mutable2': {
                'chosen': 'conv2'
            },
            'mutable3.0.kernel_size': {
                'chosen': 3
            }
        }

        model = MockModel()
        load_fix_subnet(model, fix_subnet)

        with pytest.raises(AssertionError):
            exported_fix_subnet: FixMutable = export_fix_subnet(model)
            exported_fix_subnet: FixMutable = export_fix_subnet(model)[0]

        model = MockModel()
        model.mutable1.current_choice = 'conv1'
        model.mutable2.current_choice = 'conv2'
        exported_fix_subnet = export_fix_subnet(model)
        model.mutable3[0].mutable_attrs.kernel_size.current_choice = 3
        exported_fix_subnet = export_fix_subnet(model)[0]

        mutable1_dump_chosen = exported_fix_subnet['mutable1']
        mutable2_dump_chosen = exported_fix_subnet['mutable2']
        mutable3_0_ks_chosen = exported_fix_subnet['mutable3.0.kernel_size']

        mutable1_chosen_dict = dict(chosen=mutable1_dump_chosen.chosen)
        mutable2_chosen_dict = dict(chosen=mutable2_dump_chosen.chosen)
        mutable3_0_ks_chosen_dict = dict(chosen=mutable3_0_ks_chosen.chosen)

        exported_fix_subnet['mutable1'] = mutable1_chosen_dict
        exported_fix_subnet['mutable2'] = mutable2_chosen_dict
        exported_fix_subnet['mutable3.0.kernel_size'] = \
            mutable3_0_ks_chosen_dict
        self.assertDictEqual(fix_subnet, exported_fix_subnet)

    def test_export_fix_subnet_with_derived_mutable(self) -> None:
        model = MockModelWithDerivedMutable()
        fix_subnet = export_fix_subnet(model)
        fix_subnet = export_fix_subnet(model)[0]
        self.assertDictEqual(
            fix_subnet, {'source_mutable': model.source_mutable.dump_chosen()})
            fix_subnet, {
                'source_mutable': model.source_mutable.dump_chosen(),
                'derived_mutable': model.source_mutable.dump_chosen()
            })

        fix_subnet['source_mutable'] = dict(
            fix_subnet['source_mutable']._asdict())
        fix_subnet['source_mutable']['chosen'] = 4
        load_fix_subnet(model, fix_subnet)

        assert model.source_mutable.current_choice == 4
        assert model.derived_mutable.current_choice == 8

        model = MockModelWithDerivedMutable()
        fix_subnet = export_fix_subnet(model, dump_derived_mutable=True)
        self.assertDictEqual(
            fix_subnet, {
                'source_mutable': model.source_mutable.dump_chosen(),
                'derived_mutable': model.derived_mutable.dump_chosen()
            })

        fix_subnet['source_mutable'] = dict(
            fix_subnet['source_mutable']._asdict())
        fix_subnet['source_mutable']['chosen'] = 2
        load_fix_subnet(model, fix_subnet)
        assert model.source_mutable.current_choice == 2
        assert model.derived_mutable.current_choice == 4
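
The assertions pin down the derived relationship: the derived mutable is always twice its source, so fixing the source to 4 or 2 forces 8 or 4 without the derived value ever being stored independently. `MockModelWithDerivedMutable` is defined elsewhere in the suite; a minimal analogue consistent with these assertions (hypothetical class, ratio of 2 assumed):

import torch.nn as nn
from mmrazor.models.mutables import OneShotMutableValue

class ToyDerivedModel(nn.Module):
    # Hypothetical stand-in for MockModelWithDerivedMutable.

    def __init__(self):
        super().__init__()
        self.source_mutable = OneShotMutableValue(value_list=[2, 3, 4])
        # Arithmetic on a mutable returns a DerivedMutable that
        # recomputes its choice from the source on every access.
        self.derived_mutable = self.source_mutable * 2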

@@ -11,7 +11,7 @@ from torch.nn import Conv2d, Module, Parameter
from mmrazor.models import OneShotMutableModule, ResourceEstimator
from mmrazor.models.task_modules.estimators.counters import BaseCounter
from mmrazor.registry import MODELS, TASK_UTILS
from mmrazor.structures import export_fix_subnet, load_fix_subnet
from mmrazor.structures import export_fix_subnet

_FIRST_STAGE_MUTABLE = dict(
    type='OneShotMutableOP',

@@ -216,10 +216,9 @@ class TestResourceEstimator(TestCase):
        flops_count = results['flops']
        params_count = results['params']

        fix_subnet = export_fix_subnet(model)
        load_fix_subnet(copied_model, fix_subnet)
        _, sliced_model = export_fix_subnet(model, slice_weight=True)
        subnet_results = estimator.estimate(
            model=copied_model, flops_params_cfg=flops_params_cfg)
            model=sliced_model, flops_params_cfg=flops_params_cfg)
        subnet_flops_count = subnet_results['flops']
        subnet_params_count = subnet_results['params']

@@ -38,6 +38,16 @@ class ToyDataset(Dataset):
        return dict(inputs=self.data[index], data_sample=self.label[index])


class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.architecture = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.architecture(x)


class ToyRunner:

    @property

@@ -57,7 +67,7 @@ class ToyRunner:
        pass

    def model(self):
        return nn.Conv2d
        return ToyModel()

    def logger(self):
        pass

@@ -114,8 +124,8 @@ class TestEvolutionSearchLoop(TestCase):
        self.assertIsInstance(loop, EvolutionSearchLoop)
        self.assertEqual(loop.candidates, fake_candidates)

    @patch('mmrazor.engine.runner.utils.check.load_fix_subnet')
    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.load_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.export_fix_subnet')
    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
           'get_model_flops_params')
    def test_run_epoch(self, flops_params, mock_export_fix_subnet,

@@ -131,6 +141,7 @@ class TestEvolutionSearchLoop(TestCase):
        self.runner.work_dir = self.temp_dir
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
        mock_export_fix_subnet.return_value = (fake_subnet, self.runner.model)
        load_status.return_value = True
        flops_params.return_value = 0, 0
        loop.run_epoch()

@@ -159,13 +170,12 @@ class TestEvolutionSearchLoop(TestCase):
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
        flops_params.return_value = (50., 1)
        mock_export_fix_subnet.return_value = fake_subnet
        loop.run_epoch()
        self.assertEqual(len(loop.candidates), 4)
        self.assertEqual(len(loop.top_k_candidates), 2)
        self.assertEqual(loop._epoch, 1)

    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.export_fix_subnet')
    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
           'get_model_flops_params')
    def test_run_loop(self, mock_flops, mock_export_fix_subnet):

@@ -179,6 +189,7 @@ class TestEvolutionSearchLoop(TestCase):
        loop._epoch = 1

        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        mock_export_fix_subnet.return_value = (fake_subnet, self.runner.model)
        self.runner.work_dir = self.temp_dir
        loop.update_candidate_pool = MagicMock()
        loop.val_candidate_pool = MagicMock()

@@ -197,7 +208,7 @@ class TestEvolutionSearchLoop(TestCase):
            MagicMock(return_value=crossover_candidates)
        loop.candidates = Candidates([fake_subnet] * 4)
        mock_flops.return_value = (0.5, 101)
        mock_export_fix_subnet.return_value = fake_subnet
        torch.save = MagicMock()
        loop.run()
        assert os.path.exists(
            os.path.join(self.temp_dir, 'best_fix_subnet.yaml'))

@@ -271,13 +282,12 @@ class TestEvolutionSearchLoopWithPredictor(TestCase):
        self.assertIsInstance(loop, EvolutionSearchLoop)
        self.assertEqual(loop.candidates, fake_candidates)

    @patch('mmrazor.engine.runner.utils.check.load_fix_subnet')
    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.load_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.export_fix_subnet')
    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'
           'get_model_flops_params')
    def test_run_epoch(self, flops_params, mock_export_fix_subnet,
                       load_status):
        # test_run_epoch: distributed == False
        loop_cfg = copy.deepcopy(self.train_cfg)
        loop_cfg.runner = self.runner
        loop_cfg.dataloader = self.dataloader

@@ -288,6 +298,7 @@ class TestEvolutionSearchLoopWithPredictor(TestCase):
        self.runner.work_dir = self.temp_dir
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
        mock_export_fix_subnet.return_value = (fake_subnet, self.runner.model)
        load_status.return_value = True
        flops_params.return_value = 0, 0
        loop.run_epoch()

@@ -316,13 +327,12 @@ class TestEvolutionSearchLoopWithPredictor(TestCase):
        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
        flops_params.return_value = (50., 1)
        mock_export_fix_subnet.return_value = fake_subnet
        loop.run_epoch()
        self.assertEqual(len(loop.candidates), 4)
        self.assertEqual(len(loop.top_k_candidates), 2)
        self.assertEqual(loop._epoch, 1)

    @patch('mmrazor.engine.runner.utils.check.export_fix_subnet')
    @patch('mmrazor.structures.subnet.fix_subnet.export_fix_subnet')
    @patch('mmrazor.models.task_modules.predictor.metric_predictor.'
           'MetricPredictor.model2vector')
    @patch('mmrazor.models.task_modules.estimators.resource_estimator.'

@@ -340,6 +350,7 @@ class TestEvolutionSearchLoopWithPredictor(TestCase):

        fake_subnet = {'1': 'choice1', '2': 'choice2'}
        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
        mock_export_fix_subnet.return_value = (fake_subnet, self.runner.model)

        self.runner.work_dir = self.temp_dir
        loop.update_candidate_pool = MagicMock()

@@ -360,10 +371,9 @@ class TestEvolutionSearchLoopWithPredictor(TestCase):
        loop.candidates = Candidates([fake_subnet] * 4)

        mock_flops.return_value = (0.5, 101)
        mock_export_fix_subnet.return_value = fake_subnet
        mock_model2vector.return_value = dict(
            normal_vector=[0, 1], onehot_vector=[0, 1, 0, 1])

        torch.save = MagicMock()
        loop.run()
        assert os.path.exists(
            os.path.join(self.temp_dir, 'best_fix_subnet.yaml'))

@@ -9,26 +9,42 @@ def parse_args():
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('checkpoint', help='input checkpoint filename')
    parser.add_argument('--depth', nargs='+', type=int, help='layer depth')
    parser.add_argument(
        '--inplace', action='store_true', help='replace origin ckpt')
    args = parser.parse_args()
    return args


def block2layer_index_convert(layer_depth):
    """Build index_table from OFA blocks to MMRazor layers."""
    index_table = dict()
    i = 0
    first_index = 1
    second_index = 0
    for k in layer_depth:
        for _ in range(k):
            index_table[str(i)] = str(first_index) + '.' + str(second_index)
            i += 1
            second_index += 1
        second_index = 0
        first_index += 1

    return index_table


def main():
    args = parse_args()
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    new_state_dict = dict()

    index_table = block2layer_index_convert(args.depth)

    for key, value in checkpoint['state_dict'].items():
        if 'blocks.0.' in key:
            new_key = key.replace('blocks.0', 'layer1.0')
        elif 'blocks' in key:
            for i in range(1, 21):
                if 'blocks.' + str(i) in key:
                    new_key = key.replace(
                        'blocks.' + str(i),
                        'layer' + str(i // 4 + 1) + '.' + str((i + 3) % 4))
        if 'blocks' in key:
            index = key.split('.')[1]
            new_key = key.replace('blocks.' + index,
                                  'layer' + index_table[index])
        else:
            new_key = key
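
The lookup table replaces the old hard-coded divide-by-four mapping, which silently assumed four blocks per layer; `--depth` now generalizes the conversion to any stage depths. A worked example of the table it builds:

# For layer_depth = [2, 3], consecutive OFA block indices map to
# 'layer.block' positions that restart at 0 in each MMRazor layer:
#   {'0': '1.0', '1': '1.1', '2': '2.0', '3': '2.1', '4': '2.2'}
index_table = block2layer_index_convert([2, 3])
assert index_table['2'] == '2.0'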

@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time

import torch
from mmengine.config import Config
from mmengine.runner import Runner

from mmrazor.structures.subnet import load_fix_subnet
from mmrazor.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert a NAS supernet checkpoint to a subnet checkpoint')
    parser.add_argument('config', help='NAS model config file path')
    parser.add_argument('checkpoint', help='supernet checkpoint file path')
    parser.add_argument('yaml', help='path to the YAML with subnet settings')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    register_all_modules(False)
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher

    cfg.load_from = args.checkpoint
    cfg.work_dir = '/'.join(args.checkpoint.split('/')[:-1])

    runner = Runner.from_cfg(cfg)

    load_fix_subnet(runner.model, args.yaml)

    timestamp_subnet = time.strftime('%Y%m%d_%H%M', time.localtime())
    model_name = f'subnet_{timestamp_subnet}.pth'
    save_path = osp.join(runner.work_dir, model_name)
    torch.save({
        'state_dict': runner.model.state_dict(),
        'meta': {}
    }, save_path)
    runner.logger.info(f'Successfully converted. Saved in {save_path}.')


if __name__ == '__main__':
    main()