diff --git a/basicsr/demo.py b/basicsr/demo.py
index 2f89966..a6fb539 100644
--- a/basicsr/demo.py
+++ b/basicsr/demo.py
@@ -9,7 +9,7 @@ import torch
# from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.train import parse_options
-from basicsr.utils import FileClient, imfrombytes, img2tensor, padding
+from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite
# from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
# make_exp_dirs)
@@ -37,10 +37,24 @@ def main():
## 2. run inference
+ opt['dist'] = False
model = create_model(opt)
- model.single_image_inference(img, output_path)
- print('inference {} .. finished.'.format(img_path))
+ model.feed_data(data={'lq': img.unsqueeze(dim=0)})
+
+ if model.opt['val'].get('grids', False):
+ model.grids()
+
+ model.test()
+
+ if model.opt['val'].get('grids', False):
+ model.grids_inverse()
+
+ visuals = model.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ imwrite(sr_img, output_path)
+
+ print(f'inference {img_path} .. finished. saved to {output_path}')
if __name__ == '__main__':
main()
diff --git a/basicsr/train.py b/basicsr/train.py
index 7b58a9f..9cc8f2a 100644
--- a/basicsr/train.py
+++ b/basicsr/train.py
@@ -35,6 +35,10 @@ def parse_options(is_train=True):
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
+
+ parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.')
+ parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.')
+
args = parser.parse_args()
opt = parse(args.opt, is_train=is_train)
@@ -59,6 +63,12 @@ def parse_options(is_train=True):
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
+ if args.input_path is not None and args.output_path is not None:
+ opt['img_path'] = {
+ 'input_img': args.input_path,
+ 'output_img': args.output_path
+ }
+
return opt
diff --git a/demo/blurry.png b/demo/blurry.png
new file mode 100644
index 0000000..23b31dd
Binary files /dev/null and b/demo/blurry.png differ
diff --git a/demo/deblur_img.png b/demo/deblur_img.png
new file mode 100644
index 0000000..c2d6b67
Binary files /dev/null and b/demo/deblur_img.png differ
diff --git a/demo/denoise_img.png b/demo/denoise_img.png
new file mode 100644
index 0000000..2f74439
Binary files /dev/null and b/demo/denoise_img.png differ
diff --git a/demo/noisy.png b/demo/noisy.png
new file mode 100644
index 0000000..f709a31
Binary files /dev/null and b/demo/noisy.png differ
diff --git a/experiments/pretrained_models/README.md b/experiments/pretrained_models/README.md
index d7f613e..26367fb 100644
--- a/experiments/pretrained_models/README.md
+++ b/experiments/pretrained_models/README.md
@@ -1,5 +1,4 @@
### Pretrained NAFNet Models
---
-* [NAFNet-SIDD-width64](https://drive.google.com/file/d/14Fht1QQJ2gMlk4N1ERCRuElg8JfjrWWR/view?usp=sharing)
-* [NAFNet-GoPro-width64](https://drive.google.com/file/d/1S0PVRbyTakYY9a82kujgZLbMihfNBLfC/view?usp=sharing)
+Please refer to https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models and download the pretrained models into ./experiments/pretrained_models
diff --git a/readme.md b/readme.md
index 2effb3d..3082255 100644
--- a/readme.md
+++ b/readme.md
@@ -35,8 +35,18 @@ python setup.py develop --no_cuda_ext
### Quick Start
* Image Denoise Colab Demo: [
](https://colab.research.google.com/drive/1dkO5AyktmBoWwxBwoKFUurIDn0m4qDXT?usp=sharing)
* Image Deblur Colab Demo: [
](https://colab.research.google.com/drive/1yR2ClVuMefisH12d_srXMhHnHwwA1YmU?usp=sharing)
-
-
+* Single Image Inference Demo:
+ * Image Denoise:
+ ```
+ python basicsr/demo.py -opt options/test/SIDD/NAFNet-width64.yml --input_path ./demo/noisy.png --output_path ./demo/denoise_img.png
+ ```
+ * Image Deblur:
+ ```
+ python basicsr/demo.py -opt options/test/GoPro/NAFNet-width64.yml --input_path ./demo/blurry.png --output_path ./demo/deblur_img.png
+ ```
+ * ```--input_path```: the path of the degraded image
+ * ```--output_path```: the path to save the predicted image
+ * [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded.
### Results and Pre-trained Models
@@ -78,8 +88,8 @@ If you have any questions, please contact chenliangyu@megvii.com or chuxiaojie@m
statistics
-
+

-
+