add demo for stereo sr
parent
9966dfa938
commit
34622eb121
|
@ -0,0 +1,118 @@
|
|||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
import torch
|
||||
|
||||
# from basicsr.data import create_dataloader, create_dataset
|
||||
from basicsr.models import create_model
|
||||
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite, set_random_seed
|
||||
|
||||
import argparse
|
||||
from basicsr.utils.options import dict2str, parse
|
||||
from basicsr.utils.dist_util import get_dist_info, init_dist
|
||||
import random
|
||||
|
||||
def parse_options(is_train=True):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'-opt', type=str, required=True, help='Path to option YAML file.')
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm'],
|
||||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
||||
parser.add_argument('--input_l_path', type=str, required=True, help='The path to the input left image. For stereo image inference only.')
|
||||
parser.add_argument('--input_r_path', type=str, required=True, help='The path to the input right image. For stereo image inference only.')
|
||||
parser.add_argument('--output_l_path', type=str, required=True, help='The path to the output left image. For stereo image inference only.')
|
||||
parser.add_argument('--output_r_path', type=str, required=True, help='The path to the output right image. For stereo image inference only.')
|
||||
|
||||
args = parser.parse_args()
|
||||
opt = parse(args.opt, is_train=is_train)
|
||||
|
||||
# distributed settings
|
||||
if args.launcher == 'none':
|
||||
opt['dist'] = False
|
||||
print('Disable distributed.', flush=True)
|
||||
else:
|
||||
opt['dist'] = True
|
||||
if args.launcher == 'slurm' and 'dist_params' in opt:
|
||||
init_dist(args.launcher, **opt['dist_params'])
|
||||
else:
|
||||
init_dist(args.launcher)
|
||||
print('init dist .. ', args.launcher)
|
||||
|
||||
opt['rank'], opt['world_size'] = get_dist_info()
|
||||
|
||||
# random seed
|
||||
seed = opt.get('manual_seed')
|
||||
if seed is None:
|
||||
seed = random.randint(1, 10000)
|
||||
opt['manual_seed'] = seed
|
||||
set_random_seed(seed + opt['rank'])
|
||||
|
||||
opt['img_path'] = {
|
||||
'input_l': args.input_l_path,
|
||||
'input_r': args.input_r_path,
|
||||
'output_l': args.output_l_path,
|
||||
'output_r': args.output_r_path
|
||||
}
|
||||
|
||||
return opt
|
||||
|
||||
def imread(img_path):
|
||||
file_client = FileClient('disk')
|
||||
img_bytes = file_client.get(img_path, None)
|
||||
try:
|
||||
img = imfrombytes(img_bytes, float32=True)
|
||||
except:
|
||||
raise Exception("path {} not working".format(img_path))
|
||||
|
||||
img = img2tensor(img, bgr2rgb=True, float32=True)
|
||||
return img
|
||||
|
||||
def main():
|
||||
# parse options, set distributed setting, set ramdom seed
|
||||
opt = parse_options(is_train=False)
|
||||
|
||||
img_l_path = opt['img_path'].get('input_l')
|
||||
img_r_path = opt['img_path'].get('input_r')
|
||||
output_l_path = opt['img_path'].get('output_l')
|
||||
output_r_path = opt['img_path'].get('output_r')
|
||||
|
||||
## 1. read image
|
||||
img_l = imread(img_l_path)
|
||||
img_r = imread(img_r_path)
|
||||
img = torch.cat([img_l, img_r], dim=0)
|
||||
|
||||
## 2. run inference
|
||||
opt['dist'] = False
|
||||
model = create_model(opt)
|
||||
|
||||
model.feed_data(data={'lq': img.unsqueeze(dim=0)})
|
||||
|
||||
if model.opt['val'].get('grids', False):
|
||||
model.grids()
|
||||
|
||||
model.test()
|
||||
|
||||
if model.opt['val'].get('grids', False):
|
||||
model.grids_inverse()
|
||||
|
||||
visuals = model.get_current_visuals()
|
||||
sr_img_l = visuals['result'][:,:3]
|
||||
sr_img_r = visuals['result'][:,3:]
|
||||
sr_img_l, sr_img_r = tensor2img([sr_img_l, sr_img_r])
|
||||
imwrite(sr_img_l, output_l_path)
|
||||
imwrite(sr_img_r, output_r_path)
|
||||
|
||||
print(f'inference {img_l_path} .. finished. saved to {output_l_path}')
|
||||
print(f'inference {img_r_path} .. finished. saved to {output_r_path}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 120 KiB |
Binary file not shown.
After Width: | Height: | Size: 121 KiB |
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
Binary file not shown.
After Width: | Height: | Size: 1.6 MiB |
12
readme.md
12
readme.md
|
@ -60,6 +60,18 @@ python setup.py develop --no_cuda_ext
|
|||
* ```--input_path```: the path of the degraded image
|
||||
* ```--output_path```: the path to save the predicted image
|
||||
* [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded.
|
||||
* Stereo Image Inference Demo:
|
||||
* Stereo Image Super-resolution:
|
||||
```
|
||||
python basicsr/demo_ssr.py -opt options/test/NAFSSR/NAFSSR-L_4x.yml \
|
||||
--input_l_path ./demo/lr_img_l.png --input_r_path ./demo/lr_img_r.png \
|
||||
--output_l_path ./demo/sr_img_l.png --output_r_path ./demo/sr_img_r.png
|
||||
```
|
||||
* ```--input_l_path```: the path of the degraded left image
|
||||
* ```--input_r_path```: the path of the degraded right image
|
||||
* ```--output_l_path```: the path to save the predicted left image
|
||||
* ```--output_r_path```: the path to save the predicted right image
|
||||
* [pretrained models](https://github.com/megvii-research/NAFNet/#results-and-pre-trained-models) should be downloaded.
|
||||
* Try the web demo with all three tasks here: [](https://replicate.com/megvii-research/nafnet)
|
||||
|
||||
### Results and Pre-trained Models
|
||||
|
|
Loading…
Reference in New Issue