add dashborad dataset configuration, add nondist_validation code
parent
50cb1496d6
commit
a8d01a7e3b
|
@ -6,4 +6,5 @@ logs/
|
|||
*__pycache__*
|
||||
*.sh
|
||||
datasets
|
||||
basicsr.egg-info
|
||||
basicsr.egg-info
|
||||
.eggs/
|
|
@ -375,27 +375,120 @@ class ImageRestorationModel(BaseModel):
|
|||
for key in metrics_dict:
|
||||
metrics_dict[key] /= cnt
|
||||
|
||||
self.metric_results = metrics_dict
|
||||
self._log_validation_metric_values(current_iter, dataloader.dataset.opt['name'],
|
||||
tb_logger, metrics_dict)
|
||||
tb_logger)
|
||||
return 0.
|
||||
|
||||
def nondist_validation(self, *args, **kwargs):
|
||||
logger = get_root_logger()
|
||||
logger.warning('nondist_validation is not implemented. Run dist_validation.')
|
||||
self.dist_validation(*args, **kwargs)
|
||||
# def nondist_validation(self, *args, **kwargs):
|
||||
# logger = get_root_logger()
|
||||
# logger.warning('nondist_validation is not implemented. Run dist_validation.')
|
||||
# self.dist_validation(*args, **kwargs)
|
||||
|
||||
def nondist_validation(self, dataloader, current_iter, tb_logger,
|
||||
save_img, rgb2bgr, use_image):
|
||||
dataset_name = dataloader.dataset.opt['name']
|
||||
with_metrics = self.opt['val'].get('metrics') is not None
|
||||
if with_metrics:
|
||||
self.metric_results = {
|
||||
metric: 0
|
||||
for metric in self.opt['val']['metrics'].keys()
|
||||
}
|
||||
pbar = tqdm(total=len(dataloader), unit='image')
|
||||
|
||||
cnt = 0
|
||||
|
||||
for idx, val_data in enumerate(dataloader):
|
||||
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
|
||||
# if img_name[-1] != '9':
|
||||
# continue
|
||||
|
||||
# print('val_data .. ', val_data['lq'].size(), val_data['gt'].size())
|
||||
self.feed_data(val_data)
|
||||
if self.opt['val'].get('grids', False):
|
||||
self.grids()
|
||||
|
||||
self.test()
|
||||
|
||||
if self.opt['val'].get('grids', False):
|
||||
self.grids_inverse()
|
||||
|
||||
visuals = self.get_current_visuals()
|
||||
sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
|
||||
if 'gt' in visuals:
|
||||
gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
|
||||
del self.gt
|
||||
|
||||
# tentative for out of GPU memory
|
||||
del self.lq
|
||||
del self.output
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if save_img:
|
||||
|
||||
if self.opt['is_train']:
|
||||
|
||||
save_img_path = osp.join(self.opt['path']['visualization'],
|
||||
img_name,
|
||||
f'{img_name}_{current_iter}.png')
|
||||
|
||||
save_gt_img_path = osp.join(self.opt['path']['visualization'],
|
||||
img_name,
|
||||
f'{img_name}_{current_iter}_gt.png')
|
||||
else:
|
||||
|
||||
save_img_path = osp.join(
|
||||
self.opt['path']['visualization'], dataset_name,
|
||||
f'{img_name}.png')
|
||||
save_gt_img_path = osp.join(
|
||||
self.opt['path']['visualization'], dataset_name,
|
||||
f'{img_name}_gt.png')
|
||||
|
||||
imwrite(sr_img, save_img_path)
|
||||
imwrite(gt_img, save_gt_img_path)
|
||||
|
||||
if with_metrics:
|
||||
# calculate metrics
|
||||
opt_metric = deepcopy(self.opt['val']['metrics'])
|
||||
if use_image:
|
||||
for name, opt_ in opt_metric.items():
|
||||
metric_type = opt_.pop('type')
|
||||
self.metric_results[name] += getattr(
|
||||
metric_module, metric_type)(sr_img, gt_img, **opt_)
|
||||
else:
|
||||
for name, opt_ in opt_metric.items():
|
||||
metric_type = opt_.pop('type')
|
||||
self.metric_results[name] += getattr(
|
||||
metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_)
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Test {img_name}')
|
||||
cnt += 1
|
||||
# if cnt == 300:
|
||||
# break
|
||||
pbar.close()
|
||||
|
||||
current_metric = 0.
|
||||
if with_metrics:
|
||||
for metric in self.metric_results.keys():
|
||||
self.metric_results[metric] /= cnt
|
||||
current_metric = self.metric_results[metric]
|
||||
|
||||
self._log_validation_metric_values(current_iter, dataset_name,
|
||||
tb_logger)
|
||||
return current_metric
|
||||
|
||||
def _log_validation_metric_values(self, current_iter, dataset_name,
|
||||
tb_logger, metric_dict):
|
||||
tb_logger):
|
||||
log_str = f'Validation {dataset_name}, \t'
|
||||
for metric, value in metric_dict.items():
|
||||
for metric, value in self.metric_results.items():
|
||||
log_str += f'\t # {metric}: {value:.4f}'
|
||||
logger = get_root_logger()
|
||||
logger.info(log_str)
|
||||
|
||||
log_dict = OrderedDict()
|
||||
# for name, value in loss_dict.items():
|
||||
for metric, value in metric_dict.items():
|
||||
for metric, value in self.metric_results.items():
|
||||
log_dict[f'm_{metric}'] = value
|
||||
|
||||
self.log_dict = log_dict
|
||||
|
|
|
@ -155,7 +155,7 @@ def main():
|
|||
import os
|
||||
try:
|
||||
states = os.listdir(state_folder_path)
|
||||
except:
|
||||
except Exception:
|
||||
states = []
|
||||
|
||||
resume_state = None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# GENERATED VERSION FILE
|
||||
# TIME: Mon Apr 18 21:35:20 2022
|
||||
__version__ = '1.2.0+386ca20'
|
||||
# TIME: Fri May 5 22:09:31 2023
|
||||
__version__ = '1.2.0+50cb149'
|
||||
short_version = '1.2.0'
|
||||
version_info = (1, 2, 0)
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
# general settings
|
||||
name: NAFNet-Dashboard-width32-test
|
||||
model_type: ImageRestorationModel
|
||||
scale: 1
|
||||
num_gpu: 1 # set num_gpu: 0 for cpu mode
|
||||
manual_seed: 10
|
||||
|
||||
# dataset and data loader settings
|
||||
datasets:
|
||||
|
||||
test:
|
||||
name: dashboard-test
|
||||
type: PairedImageDataset
|
||||
|
||||
dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\gt
|
||||
dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\lq
|
||||
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
# network structures
|
||||
network_g:
|
||||
type: NAFNetLocal
|
||||
width: 32
|
||||
enc_blk_nums: [1, 1, 1, 14]
|
||||
middle_blk_num: 1
|
||||
dec_blk_nums: [1, 1, 1, 1]
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: experiments/pretrained_models/NAFNet-dashboard-width32.pth
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
save_img: true
|
||||
grids: false
|
||||
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
ssim:
|
||||
type: calculate_ssim
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
# dist training settings
|
||||
# dist_params:
|
||||
# backend: nccl
|
||||
# port: 29500
|
|
@ -0,0 +1,108 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
# general settings
|
||||
name: NAFNet-Dashboard-width32
|
||||
model_type: ImageRestorationModel
|
||||
scale: 1
|
||||
num_gpu: 1
|
||||
manual_seed: 10
|
||||
|
||||
datasets:
|
||||
train:
|
||||
name: dashboard-train
|
||||
type: PairedImageDataset
|
||||
dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\train\gt
|
||||
dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\train\lq
|
||||
|
||||
filename_tmpl: '{}'
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
gt_size: 128 # Cropped patched size for gt patches.
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
|
||||
# data loader
|
||||
use_shuffle: true
|
||||
num_worker_per_gpu: 20
|
||||
batch_size_per_gpu: 80
|
||||
dataset_enlarge_ratio: 1
|
||||
prefetch_mode: ~
|
||||
|
||||
val:
|
||||
name: dashboard-val
|
||||
type: PairedImageDataset
|
||||
dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\gt
|
||||
dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\lq
|
||||
io_backend:
|
||||
type: disk
|
||||
|
||||
|
||||
network_g:
|
||||
type: NAFNetLocal # network name, see in models/archs/NAFNet_arch.py
|
||||
width: 32
|
||||
enc_blk_nums: [1, 1, 1, 14]
|
||||
middle_blk_num: 1
|
||||
dec_blk_nums: [1, 1, 1, 1]
|
||||
|
||||
# path
|
||||
path:
|
||||
pretrain_network_g: ~
|
||||
strict_load_g: true
|
||||
resume_state: ~
|
||||
|
||||
# training settings
|
||||
train:
|
||||
optim_g:
|
||||
type: AdamW
|
||||
lr: !!float 1e-3
|
||||
weight_decay: !!float 1e-3
|
||||
betas: [0.9, 0.9]
|
||||
|
||||
scheduler:
|
||||
type: TrueCosineAnnealingLR
|
||||
T_max: 1000
|
||||
eta_min: !!float 1e-7
|
||||
|
||||
total_iter: 1000
|
||||
warmup_iter: -1 # no warm up
|
||||
|
||||
# losses
|
||||
pixel_opt:
|
||||
type: PSNRLoss
|
||||
loss_weight: 1
|
||||
reduction: mean
|
||||
|
||||
# validation settings
|
||||
val:
|
||||
val_freq: !!float 2e2
|
||||
save_img: false
|
||||
|
||||
|
||||
metrics:
|
||||
psnr: # metric name, can be arbitrary
|
||||
type: calculate_psnr
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
ssim:
|
||||
type: calculate_ssim
|
||||
crop_border: 0
|
||||
test_y_channel: false
|
||||
|
||||
# logging settings
|
||||
logger:
|
||||
print_freq: 10
|
||||
save_checkpoint_freq: !!float 2e2
|
||||
use_tb_logger: true
|
||||
wandb:
|
||||
project: ~
|
||||
resume_id: ~
|
||||
|
||||
# dist training settings
|
||||
# dist_params:
|
||||
# backend: nccl
|
||||
# port: 29500
|
Loading…
Reference in New Issue