hanrui1sensetime 5c87dd9565
[2.0] Support mmedit 2.0 (#1017)
* mmcv.Config -> mmengine Config

* support mmedit part

* add rewriter for BaseEditModels

* fix visualizer

* mmedit visualization

* remove unused code

* fix realesrgan

* fix trt

* support MultiTestLoop; rewriter fix mmediting bugs; fix ut

* fix uts

* fix mmedit sdk

* fix regression test(part)

* fix torchscript

* part of fix regression test

* fix checkenv.py

* fix test.py for mmedit2.0

* support for mmedit

* fix regression_test

* fix check copyright ci

* fix isort

* fix docformatter

* fix yapf

* fix tests

* fix sdk after 1040

* add a file for ut

* fix docformatter

* fix export info

* fix super_resolution

* fix test.py

* stage configs

* remove unused code

* remove rewriter of multitestloop

* fix yapf
2022-09-20 19:22:55 +08:00

255 lines
7.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
default_scope = 'mmedit'

# Output root handed to the checkpoint hook below. `save_dir` is assigned
# again further down in this file (with a trailing slash).
save_dir = './work_dirs'

# NOTE(review): `default_hooks` is assigned a second time later in this file.
# Under plain top-to-bottom execution that later dict replaces this one, so
# `max_keep_ckpts`, `save_best` and `rule` set here would not survive --
# confirm this is intended.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        out_dir=save_dir,
        by_epoch=False,
        max_keep_ckpts=10,
        save_best='PSNR',
        rule='greater',
    ),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)
# Environment options: cuDNN benchmark flag, multiprocessing settings and
# the distributed backend.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(
        mp_start_method='fork',
        opencv_num_threads=4,
    ),
    dist_cfg=dict(backend='nccl'),
)

# Logging: level plus an iteration-based log processor with a window of 100.
log_level = 'INFO'
log_processor = dict(
    type='LogProcessor',
    window_size=100,
    by_epoch=False,
)

# No checkpoint path is set and `resume` is off.
load_from = None
resume = False

# The experiment name is also the leaf directory of `work_dir`.
experiment_name = 'srcnn_x4k915_1xb16-1000k_div2k'
work_dir = f'./work_dirs/{experiment_name}'

# Rebinds `save_dir`; statements after this point see the trailing-slash
# form.
save_dir = './work_dirs/'
# `scale` feeds both the generator's `upscale_factor` and the test-time
# `crop_border` in the model dict below.
scale = 4

# model settings
model = dict(
    type='BaseEditModel',
    generator=dict(
        type='SRCNNNet',
        channels=(3, 64, 32, 3),
        kernel_sizes=(9, 1, 5),
        upscale_factor=scale,
    ),
    pixel_loss=dict(
        type='L1Loss',
        loss_weight=1.0,
        reduction='mean',
    ),
    train_cfg=dict(),
    test_cfg=dict(
        metrics=['PSNR'],
        crop_border=scale,
    ),
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0.0, 0.0, 0.0],
        std=[255.0, 255.0, 255.0],
    ),
)
# Training pipeline. The 'img' and 'gt' images are loaded with identical
# decoding options, `scale` is injected via SetValues, and the pair is then
# cropped, randomly flipped/transposed, converted to tensors and packed.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(type='SetValues', dictionary=dict(scale=scale)),
    dict(type='PairedRandomCrop', gt_patch_size=128),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='vertical',
    ),
    dict(
        type='RandomTransposeHW',
        keys=['img', 'gt'],
        transpose_ratio=0.5,
    ),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs'),
]
# Validation pipeline: load the 'img' and 'gt' images with identical
# decoding options, convert both to tensors, then pack them. No
# augmentation steps are listed here.
val_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2',
    ),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs'),
]
# dataset settings
dataset_type = 'BasicImageDataset'
data_root = 'data'

# Training loader: infinite shuffled sampler over the DIV2K sub-image lists
# under `<data_root>/DIV2K`.
train_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        ann_file='meta_info_DIV2K800sub_GT.txt',
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        data_root=f'{data_root}/DIV2K',
        data_prefix=dict(
            img='DIV2K_train_LR_bicubic/X4_sub',
            gt='DIV2K_train_HR_sub',
        ),
        filename_tmpl=dict(img='{}', gt='{}'),
        pipeline=train_pipeline,
    ),
)
# Validation loader: unshuffled default sampler over `<data_root>/Set5`,
# keeping the last partial batch (`drop_last=False`).
val_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='set5', task_name='sisr'),
        data_root=f'{data_root}/Set5',
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=val_pipeline,
    ),
)
# Validation metrics; PSNR and SSIM take `scale` as their `crop_border`,
# MAE is configured without one.
val_evaluator = [
    dict(type='MAE'),
    dict(
        type='PSNR',
        crop_border=scale,
    ),
    dict(
        type='SSIM',
        crop_border=scale,
    ),
]
# Iteration-based training loop: 1,000,000 iterations with a validation
# interval of 5,000.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=1_000_000,
    val_interval=5_000,
)
val_cfg = dict(type='ValLoop')
# optimizer
# Adam with lr 2e-4 and betas (0.9, 0.99), wrapped by `OptimWrapper` and
# built through the named constructor.
optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(
        type='Adam',
        lr=2e-4,
        betas=(0.9, 0.99),
    ),
)
# learning policy
# Four equal 250k periods (summing to the 1,000,000 listed) with unit
# restart weights and a floor of 1e-7.
param_scheduler = dict(
    type='CosineRestartLR',
    by_epoch=False,
    periods=[250_000, 250_000, 250_000, 250_000],
    restart_weights=[1, 1, 1, 1],
    eta_min=1e-7,
)
# Hook set for this config. Under plain top-to-bottom execution this
# assignment replaces any earlier `default_hooks` value instead of merging
# with it. `out_dir` picks up whatever `save_dir` is bound to at this point.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        save_optimizer=True,
        by_epoch=False,
        out_dir=save_dir,
    ),
    timer=dict(type='IterTimerHook'),
    logger=dict(
        type='LoggerHook',
        interval=100,
    ),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)
# Shared test-time pipeline: load the 'img' and 'gt' images with identical
# decoding options, convert both to tensors, then pack them.
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        key='img',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2'),
    dict(
        type='LoadImageFromFile',
        key='gt',
        color_type='color',
        channel_order='rgb',
        imdecode_backend='cv2'),
    dict(type='ToTensor', keys=['img', 'gt']),
    dict(type='PackEditInputs')
]

# test config for Set5
set5_data_root = 'data/Set5'
set5_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set5', task_name='sisr'),
        data_root=set5_data_root,
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=test_pipeline))
# NOTE(review): `crop_border` is 2 for these test evaluators while `scale`
# is 4 elsewhere in this file -- confirm this is intended.
set5_evaluator = [
    dict(type='PSNR', crop_border=2, prefix='Set5'),
    dict(type='SSIM', crop_border=2, prefix='Set5'),
]

# test config for Set14
set14_data_root = 'data/Set14'
set14_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        # Must be the Set14 root: using the Set5 root here would make the
        # loader labelled 'set14' read Set5 images and leave
        # `set14_data_root` unused.
        data_root=set14_data_root,
        data_prefix=dict(img='LRbicx4', gt='GTmod12'),
        pipeline=test_pipeline))
set14_evaluator = [
    dict(type='PSNR', crop_border=2, prefix='Set14'),
    dict(type='SSIM', crop_border=2, prefix='Set14'),
]
# Loader over the in-repo test images; both the 'img' and 'gt' prefixes
# point at the same `imgs` directory.
ut_data_root = 'tests/test_codebase/test_mmedit/data'
ut_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        metainfo=dict(dataset_type='set14', task_name='sisr'),
        data_root=ut_data_root,
        data_prefix=dict(img='imgs', gt='imgs'),
        pipeline=test_pipeline,
    ),
)
# test config for DIV2K
div2k_data_root = 'data/DIV2K'
div2k_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='BasicImageDataset',
        ann_file='meta_info_DIV2K100sub_GT.txt',
        metainfo=dict(dataset_type='div2k', task_name='sisr'),
        data_root=div2k_data_root,
        data_prefix=dict(
            img='DIV2K_train_LR_bicubic/X4_sub',
            gt='DIV2K_train_HR_sub',
        ),
        # filename_tmpl=dict(img='{}_x4', gt='{}'),
        pipeline=test_pipeline,
    ),
)

# PSNR/SSIM for DIV2K, each reported under the 'DIV2K' prefix.
div2k_evaluator = [
    dict(
        type='PSNR',
        crop_border=2,
        prefix='DIV2K',
    ),
    dict(
        type='SSIM',
        crop_border=2,
        prefix='DIV2K',
    ),
]
# test config
# Two-entry lists for the multi-dataset test loop (presumably paired by
# position -- verify against MultiTestLoop). Both dataloader entries are the
# same `ut_dataloader` object, while the evaluators are the Set5 and Set14
# groups.
test_cfg = dict(type='MultiTestLoop')
test_dataloader = [
    ut_dataloader,
    ut_dataloader,
]
test_evaluator = [
    set5_evaluator,
    set14_evaluator,
]