diff --git a/basicsr/demo.py b/basicsr/demo.py index 2f89966..a6fb539 100644 --- a/basicsr/demo.py +++ b/basicsr/demo.py @@ -9,7 +9,7 @@ import torch # from basicsr.data import create_dataloader, create_dataset from basicsr.models import create_model from basicsr.train import parse_options -from basicsr.utils import FileClient, imfrombytes, img2tensor, padding +from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite # from basicsr.utils import (get_env_info, get_root_logger, get_time_str, # make_exp_dirs) @@ -37,10 +37,24 @@ def main(): ## 2. run inference + opt['dist'] = False model = create_model(opt) - model.single_image_inference(img, output_path) - print('inference {} .. finished.'.format(img_path)) + model.feed_data(data={'lq': img.unsqueeze(dim=0)}) + + if model.opt['val'].get('grids', False): + model.grids() + + model.test() + + if model.opt['val'].get('grids', False): + model.grids_inverse() + + visuals = model.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + imwrite(sr_img, output_path) + + print(f'inference {img_path} .. finished. saved to {output_path}') if __name__ == '__main__': main() diff --git a/basicsr/train.py b/basicsr/train.py index 7b58a9f..9cc8f2a 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -35,6 +35,10 @@ def parse_options(is_train=True): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + + parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.') + parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.') + args = parser.parse_args() opt = parse(args.opt, is_train=is_train) @@ -59,6 +63,12 @@ def parse_options(is_train=True): opt['manual_seed'] = seed set_random_seed(seed + opt['rank']) + if args.input_path is not None and args.output_path is not None: + opt['img_path'] = { + 'input_img': args.input_path, + 'output_img': args.output_path + } + return opt diff --git a/demo/blurry.png b/demo/blurry.png new file mode 100644 index 0000000..23b31dd Binary files /dev/null and b/demo/blurry.png differ diff --git a/demo/deblur_img.png b/demo/deblur_img.png new file mode 100644 index 0000000..c2d6b67 Binary files /dev/null and b/demo/deblur_img.png differ diff --git a/demo/denoise_img.png b/demo/denoise_img.png new file mode 100644 index 0000000..2f74439 Binary files /dev/null and b/demo/denoise_img.png differ diff --git a/demo/noisy.png b/demo/noisy.png new file mode 100644 index 0000000..f709a31 Binary files /dev/null and b/demo/noisy.png differ diff --git a/experiments/pretrained_models/README.md b/experiments/pretrained_models/README.md index d7f613e..26367fb 100644 --- a/experiments/pretrained_models/README.md +++ b/experiments/pretrained_models/README.md @@ -1,5 +1,4 @@ ### Pretrained NAFNet Models --- -* [NAFNet-SIDD-width64](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing) -* [NAFNet-GoPro-width64](https://drive.google.com/file/d/1S0PVRbyTakYY9a82kujgZLbMihfNBLfC/view?usp=sharing) +please refer to https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models, and download the pretrained models into ./experiments/pretrained_models diff --git a/readme.md b/readme.md index 2effb3d..3082255 100644 --- a/readme.md +++ b/readme.md @@ -35,8 +35,18 @@ python setup.py develop --no_cuda_ext ### Quick Start * Image Denoise Colab Demo: [google colab logo](https://colab.research.google.com/drive/1dkO5AyktmBoWwxBwoKFUurIDn0m4qDXT?usp=sharing) * Image Deblur Colab Demo: [google colab logo](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing) - - +* Single Image Inference Demo: + * Image Denoise: + ``` + python basicsr/demo.py -opt options/test/SIDD/NAFNet-width64.yml --input_path ./demo/noisy.png --output_path ./demo/denoise_img.png + ``` + * Image Deblur: + ``` + python basicsr/demo.py -opt options/test/GoPro/NAFNet-width64.yml --input_path ./demo/blurry.png --output_path ./demo/deblur_img.png + ``` + * ```--input_path```: the path of the degraded image + * ```--output_path```: the path to save the predicted image + * [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded. ### Results and Pre-trained Models @@ -78,8 +88,8 @@ If you have any questions, please contact chenliangyu@megvii.com or chuxiaojie@m
statistics - + ![visitors](https://visitor-badge.glitch.me/badge?page_id=megvii-research/NAFNet) - +