mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* mmcv.Config -> mmengine Config * support mmedit part * add rewriter for BaseEditModels * fix visualizer * mmedit visualization * remove unused code * fix realesrgan * fix trt * support MultiTestLoop; rewriter fix mmediting bugs; fix ut * fix uts * fix mmedit sdk * fix regression test(part) * fix torchscript * part of fix regression test * fix checkenv.py * fix test.py for mmedit2.0 * support for mmedit * fix regression_test * fix check copyright ci * fix isort * fix docformatter * fix yapf * fix tests * fix sdk after 1040 * add a file for ut * fix docformatter * fix export info * fix super_resolution * fix test.py * stage configs * remove unused code * remove rewriter of multitestloop * fix yapf
255 lines
7.2 KiB
Python
255 lines
7.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Scope name under which the `type=` strings in this config are looked up.
default_scope = 'mmedit'
# Checkpoint output directory used by the hooks below.  NOTE: `save_dir` is
# assigned again further down in this file (to './work_dirs/').
save_dir = './work_dirs'

# Base runtime hooks.
# NOTE(review): `default_hooks` is assigned a second time later in this file.
# As plain Python that later assignment rebinds the name, so the
# `max_keep_ckpts`, `save_best` and `rule` entries below only take effect if
# whatever loads this file merges the two dicts -- confirm which is intended.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        out_dir=save_dir,
        by_epoch=False,
        max_keep_ckpts=10,
        save_best='PSNR',
        rule='greater',
    ),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)
|
|
|
|
# Environment settings: cuDNN benchmark switch, multi-processing options and
# the distributed backend.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=4),
    dist_cfg=dict(backend='nccl'),
)

# Logging: level plus an iteration-based (by_epoch=False) log processor with
# a window of 100.
log_level = 'INFO'
log_processor = dict(type='LogProcessor', window_size=100, by_epoch=False)

# No checkpoint is loaded and no run is resumed by default.
load_from = None
resume = False

# The work directory is derived from the experiment name.
experiment_name = 'srcnn_x4k915_1xb16-1000k_div2k'
work_dir = f'./work_dirs/{experiment_name}'
# Re-assignment of `save_dir` (an earlier assignment in this file uses
# './work_dirs' without the trailing slash); later blocks see this value.
save_dir = './work_dirs/'
|
|
|
|
# Upscaling factor; reused below for the generator's `upscale_factor` and the
# test-time `crop_border`.
scale = 4

# model settings
model = dict(
    type='BaseEditModel',
    generator=dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=scale,
    ),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
    train_cfg=dict(),
    test_cfg=dict(metrics=['PSNR'], crop_border=scale),
    # Zero mean and a std of 255 per channel.
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
    ),
)
|
|
|
|
# Training pipeline: load the 'img' and 'gt' images, record the scale, take a
# paired random crop, apply random flips / transpose (ratio 0.5 each), then
# convert to tensors and pack.  `scale` is defined earlier in this file.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(type='SetValues', dictionary=dict(scale=scale)),
    dict(type='PairedRandomCrop', gt_patch_size=128),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='vertical',
    ),
    dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs'),
]
|
|
# Validation pipeline: load the 'img' and 'gt' images, convert to tensors and
# pack.  No cropping or augmentation steps are listed here.
val_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs'),
]
|
|
|
|
# dataset settings
dataset_type = 'BasicImageDataset'
data_root = 'data'

# Training loader over the DIV2K sub-image folders under `data_root`, sampled
# with a shuffling InfiniteSampler.  `train_pipeline` is defined earlier in
# this file.
train_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file='meta_info_DIV2K800sub_GT.txt',
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        data_root=data_root + '/DIV2K',
        data_prefix=dict(
            img='DIV2K_train_LR_bicubic/X4_sub',
            gt='DIV2K_train_HR_sub',
        ),
        filename_tmpl=dict(img='{}', gt='{}'),
        pipeline=train_pipeline,
    ),
)
|
|
|
|
# Validation loader over the Set5 folder under `data_root`, in fixed order
# (shuffle=False) and keeping the last partial batch (drop_last=False).
# `dataset_type`, `data_root` and `val_pipeline` are defined earlier in this
# file.
val_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='set5', task_name='sisr'),
        data_root=data_root + '/Set5',
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=val_pipeline,
    ),
)
|
|
|
|
# Validation metrics.  PSNR and SSIM use `scale` (defined earlier in this
# file) as their `crop_border`; MAE takes no extra arguments here.
val_evaluator = [
    dict(type='MAE'),
    dict(type='PSNR', crop_border=scale),
    dict(type='SSIM', crop_border=scale),
]
|
|
|
|
# Iteration-based training loop: 1,000,000 iterations with `val_interval`
# set to 5000.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=1000000,
    val_interval=5000,
)
val_cfg = dict(type='ValLoop')
|
|
|
|
# optimizer
# Adam with lr 2e-4 and betas (0.9, 0.99), wrapped in an OptimWrapper.
optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(type='Adam', lr=2e-4, betas=(0.9, 0.99)),
)
|
|
|
|
# learning policy
# Four periods of 250,000 iterations each (1,000,000 in total), all with a
# restart weight of 1.
param_scheduler = dict(
    type='CosineRestartLR',
    by_epoch=False,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1],
    eta_min=1e-7,
)
|
|
|
|
# Runtime hooks.  This rebinds `default_hooks`, replacing the dict assigned
# near the top of this file; `save_dir` here is the value from its most
# recent assignment above ('./work_dirs/').
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        save_optimizer=True,
        by_epoch=False,
        out_dir=save_dir,
    ),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)
|
|
|
|
# Test pipeline shared by the test dataloaders: load the 'img' and 'gt'
# images, convert to tensors and pack.
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs'),
]

# test config for Set5
set5_data_root = 'data/Set5'
set5_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set5', task_name='sisr'),
        data_root=set5_data_root,
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=test_pipeline,
    ),
)
set5_evaluator = [
    dict(type='PSNR', crop_border=2, prefix='Set5'),
    dict(type='SSIM', crop_border=2, prefix='Set5'),
]

# test config for Set14
set14_data_root = 'data/Set14'
set14_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        # Use the Set14 root: this previously pointed at `set5_data_root`,
        # which left `set14_data_root` unused and made this loader read Set5.
        data_root=set14_data_root,
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=test_pipeline,
    ),
)
set14_evaluator = [
    dict(type='PSNR', crop_border=2, prefix='Set14'),
    dict(type='SSIM', crop_border=2, prefix='Set14'),
]
|
|
|
|
# Loader over the data directory under tests/.  Both the 'img' and 'gt'
# prefixes point at the same 'imgs' folder.  `test_pipeline` is defined
# earlier in this file.
# NOTE(review): metainfo reports dataset_type='set14' although the data root
# is the unit-test directory -- confirm this label is intended.
ut_data_root = 'tests/test_codebase/test_mmedit/data'
ut_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        data_root=ut_data_root,
        data_prefix=dict(img='imgs', gt='imgs'),
        pipeline=test_pipeline,
    ),
)
|
|
|
|
# test config for DIV2K
# `test_pipeline` is defined earlier in this file.
div2k_data_root = 'data/DIV2K'
div2k_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        ann_file='meta_info_DIV2K100sub_GT.txt',
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        data_root=div2k_data_root,
        data_prefix=dict(
            img='DIV2K_train_LR_bicubic/X4_sub',
            gt='DIV2K_train_HR_sub',
        ),
        # filename_tmpl=dict(img='{}_x4', gt='{}'),
        pipeline=test_pipeline,
    ),
)
div2k_evaluator = [
    dict(type='PSNR', crop_border=2, prefix='DIV2K'),
    dict(type='SSIM', crop_border=2, prefix='DIV2K'),
]
|
|
|
|
# test config
# The test loop receives one dataloader and one evaluator list per entry.
# NOTE(review): both dataloader entries are `ut_dataloader`, while the
# evaluators carry the 'Set5' / 'Set14' prefixes -- presumably so the loop
# runs on the repo-local unit-test data; confirm this pairing is intended.
# `ut_dataloader`, `set5_evaluator` and `set14_evaluator` are defined earlier
# in this file.
test_cfg = dict(type='MultiTestLoop')
test_dataloader = [ut_dataloader, ut_dataloader]
test_evaluator = [set5_evaluator, set14_evaluator]
|