diff --git a/.gitignore b/.gitignore index 032d2dd..2dea1cc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ logs/ *__pycache__* *.sh datasets -basicsr.egg-info \ No newline at end of file +basicsr.egg-info +.eggs/ \ No newline at end of file diff --git a/basicsr/models/image_restoration_model.py b/basicsr/models/image_restoration_model.py index 1eec564..6cb39eb 100644 --- a/basicsr/models/image_restoration_model.py +++ b/basicsr/models/image_restoration_model.py @@ -375,27 +375,120 @@ class ImageRestorationModel(BaseModel): for key in metrics_dict: metrics_dict[key] /= cnt + self.metric_results = metrics_dict self._log_validation_metric_values(current_iter, dataloader.dataset.opt['name'], - tb_logger, metrics_dict) + tb_logger) return 0. - def nondist_validation(self, *args, **kwargs): - logger = get_root_logger() - logger.warning('nondist_validation is not implemented. Run dist_validation.') - self.dist_validation(*args, **kwargs) + # def nondist_validation(self, *args, **kwargs): + # logger = get_root_logger() + # logger.warning('nondist_validation is not implemented. Run dist_validation.') + # self.dist_validation(*args, **kwargs) + def nondist_validation(self, dataloader, current_iter, tb_logger, + save_img, rgb2bgr, use_image): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = { + metric: 0 + for metric in self.opt['val']['metrics'].keys() + } + pbar = tqdm(total=len(dataloader), unit='image') + + cnt = 0 + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + # if img_name[-1] != '9': + # continue + + # print('val_data .. ', val_data['lq'].size(), val_data['gt'].size()) + self.feed_data(val_data) + if self.opt['val'].get('grids', False): + self.grids() + + self.test() + + if self.opt['val'].get('grids', False): + self.grids_inverse() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + + if self.opt['is_train']: + + save_img_path = osp.join(self.opt['path']['visualization'], + img_name, + f'{img_name}_{current_iter}.png') + + save_gt_img_path = osp.join(self.opt['path']['visualization'], + img_name, + f'{img_name}_{current_iter}_gt.png') + else: + + save_img_path = osp.join( + self.opt['path']['visualization'], dataset_name, + f'{img_name}.png') + save_gt_img_path = osp.join( + self.opt['path']['visualization'], dataset_name, + f'{img_name}_gt.png') + + imwrite(sr_img, save_img_path) + imwrite(gt_img, save_gt_img_path) + + if with_metrics: + # calculate metrics + opt_metric = deepcopy(self.opt['val']['metrics']) + if use_image: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(sr_img, gt_img, **opt_) + else: + for name, opt_ in opt_metric.items(): + metric_type = opt_.pop('type') + self.metric_results[name] += getattr( + metric_module, metric_type)(visuals['result'], visuals['gt'], **opt_) + + pbar.update(1) + pbar.set_description(f'Test {img_name}') + cnt += 1 + # if cnt == 300: + # break + pbar.close() + + current_metric = 0. + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= cnt + current_metric = self.metric_results[metric] + + self._log_validation_metric_values(current_iter, dataset_name, + tb_logger) + return current_metric def _log_validation_metric_values(self, current_iter, dataset_name, - tb_logger, metric_dict): + tb_logger): log_str = f'Validation {dataset_name}, \t' - for metric, value in metric_dict.items(): + for metric, value in self.metric_results.items(): log_str += f'\t # {metric}: {value:.4f}' logger = get_root_logger() logger.info(log_str) log_dict = OrderedDict() # for name, value in loss_dict.items(): - for metric, value in metric_dict.items(): + for metric, value in self.metric_results.items(): log_dict[f'm_{metric}'] = value self.log_dict = log_dict diff --git a/basicsr/train.py b/basicsr/train.py index 9cc8f2a..9c01bb2 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -155,7 +155,7 @@ def main(): import os try: states = os.listdir(state_folder_path) - except: + except Exception: states = [] resume_state = None diff --git a/basicsr/version.py b/basicsr/version.py index 2c4fdd3..773eff1 100644 --- a/basicsr/version.py +++ b/basicsr/version.py @@ -1,5 +1,5 @@ # GENERATED VERSION FILE -# TIME: Mon Apr 18 21:35:20 2022 -__version__ = '1.2.0+386ca20' +# TIME: Fri May 5 22:09:31 2023 +__version__ = '1.2.0+50cb149' short_version = '1.2.0' version_info = (1, 2, 0) diff --git a/options/test/dashboard/NAFNet-width32.yml b/options/test/dashboard/NAFNet-width32.yml new file mode 100644 index 0000000..e68eefe --- /dev/null +++ b/options/test/dashboard/NAFNet-width32.yml @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNet-Dashboard-width32-test +model_type: ImageRestorationModel +scale: 1 +num_gpu: 1 # set num_gpu: 0 for cpu mode +manual_seed: 10 + +# dataset and data loader settings +datasets: + + test: + name: dashboard-test + type: PairedImageDataset + + dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\gt + dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\lq + + io_backend: + type: disk + +# network structures +network_g: + type: NAFNetLocal + width: 32 + enc_blk_nums: [1, 1, 1, 14] + middle_blk_num: 1 + dec_blk_nums: [1, 1, 1, 1] + +# path +path: + pretrain_network_g: experiments/pretrained_models/NAFNet-dashboard-width32.pth + strict_load_g: true + resume_state: ~ + +# validation settings +val: + save_img: true + grids: false + + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_ssim + crop_border: 0 + test_y_channel: false + +# dist training settings +# dist_params: +# backend: nccl +# port: 29500 diff --git a/options/train/dashboard/NAFNet-width32.yml b/options/train/dashboard/NAFNet-width32.yml new file mode 100644 index 0000000..80f921f --- /dev/null +++ b/options/train/dashboard/NAFNet-width32.yml @@ -0,0 +1,108 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ +# general settings +name: NAFNet-Dashboard-width32 +model_type: ImageRestorationModel +scale: 1 +num_gpu: 1 +manual_seed: 10 + +datasets: + train: + name: dashboard-train + type: PairedImageDataset + dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\train\gt + dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\train\lq + + filename_tmpl: '{}' + io_backend: + type: disk + + gt_size: 128 # Cropped patched size for gt patches. + use_flip: true + use_rot: true + + # data loader + use_shuffle: true + num_worker_per_gpu: 20 + batch_size_per_gpu: 80 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + name: dashboard-val + type: PairedImageDataset + dataroot_gt: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\gt + dataroot_lq: E:\HCH\graduate\research_1\NAFNet\datasets\dashboard.v2i.voc\val\lq + io_backend: + type: disk + + +network_g: + type: NAFNetLocal # network name, see in models/archs/NAFNet_arch.py + width: 32 + enc_blk_nums: [1, 1, 1, 14] + middle_blk_num: 1 + dec_blk_nums: [1, 1, 1, 1] + +# path +path: + pretrain_network_g: ~ + strict_load_g: true + resume_state: ~ + +# training settings +train: + optim_g: + type: AdamW + lr: !!float 1e-3 + weight_decay: !!float 1e-3 + betas: [0.9, 0.9] + + scheduler: + type: TrueCosineAnnealingLR + T_max: 1000 + eta_min: !!float 1e-7 + + total_iter: 1000 + warmup_iter: -1 # no warm up + + # losses + pixel_opt: + type: PSNRLoss + loss_weight: 1 + reduction: mean + +# validation settings +val: + val_freq: !!float 2e2 + save_img: false + + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 0 + test_y_channel: false + ssim: + type: calculate_ssim + crop_border: 0 + test_y_channel: false + +# logging settings +logger: + print_freq: 10 + save_checkpoint_freq: !!float 2e2 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +# dist_params: +# backend: nccl +# port: 29500