# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
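# Single-pair stereo inference demo: reads a left/right image pair, restores
# it with a model built from the given BasicSR-style YAML config, and writes
# the two output images.
#
# Example invocation (the script and config file names below are illustrative,
# not taken from the repository):
#   python demo_ssr.py -opt options/test/stereo_sr.yml \
#       --input_l_path left.png --input_r_path right.png \
#       --output_l_path out_left.png --output_r_path out_right.png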
import argparse
import random

import torch

# from basicsr.data import create_dataloader, create_dataset
from basicsr.models import create_model
from basicsr.utils import (FileClient, imfrombytes, img2tensor, imwrite,
                           padding, set_random_seed, tensor2img)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse


def parse_options(is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument(
        '--input_l_path', type=str, required=True,
        help='The path to the input left image. For stereo image inference only.')
    parser.add_argument(
        '--input_r_path', type=str, required=True,
        help='The path to the input right image. For stereo image inference only.')
    parser.add_argument(
        '--output_l_path', type=str, required=True,
        help='The path to the output left image. For stereo image inference only.')
    parser.add_argument(
        '--output_r_path', type=str, required=True,
        help='The path to the output right image. For stereo image inference only.')

    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
            print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # input/output image paths for the stereo pair
    opt['img_path'] = {
        'input_l': args.input_l_path,
        'input_r': args.input_r_path,
        'output_l': args.output_l_path,
        'output_r': args.output_r_path
    }

    return opt


def imread(img_path):
    """Read an image from disk and return it as an RGB float32 CHW tensor."""
    file_client = FileClient('disk')
    img_bytes = file_client.get(img_path, None)
    try:
        img = imfrombytes(img_bytes, float32=True)
    except Exception as e:
        raise Exception('path {} not working'.format(img_path)) from e

    img = img2tensor(img, bgr2rgb=True, float32=True)
    return img


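# Note: the YAML config passed via -opt is parsed by basicsr.utils.options.parse.
# Besides the usual BasicSR model/network settings, this demo reads
# manual_seed, optional dist_params (slurm only), and val.grids (which enables
# tile-based testing via model.grids()/grids_inverse()) from that config.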
def main():
    # parse options, set distributed setting, set random seed
    opt = parse_options(is_train=False)

    img_l_path = opt['img_path'].get('input_l')
    img_r_path = opt['img_path'].get('input_r')
    output_l_path = opt['img_path'].get('output_l')
    output_r_path = opt['img_path'].get('output_r')

    ## 1. read the stereo pair and stack it channel-wise
    ##    (left RGB + right RGB -> a single 6-channel input)
    img_l = imread(img_l_path)
    img_r = imread(img_r_path)
    img = torch.cat([img_l, img_r], dim=0)

    ## 2. run inference
    opt['dist'] = False
    model = create_model(opt)

    model.feed_data(data={'lq': img.unsqueeze(dim=0)})

    # optional tiled inference for large images, controlled by val.grids
    if model.opt['val'].get('grids', False):
        model.grids()

    model.test()

    if model.opt['val'].get('grids', False):
        model.grids_inverse()

    ## 3. split the 6-channel result back into left/right images and save them
    visuals = model.get_current_visuals()
    sr_img_l = visuals['result'][:, :3]
    sr_img_r = visuals['result'][:, 3:]
    sr_img_l, sr_img_r = tensor2img([sr_img_l, sr_img_r])
    imwrite(sr_img_l, output_l_path)
    imwrite(sr_img_r, output_r_path)

    print(f'inference {img_l_path} .. finished. saved to {output_l_path}')
    print(f'inference {img_r_path} .. finished. saved to {output_r_path}')


if __name__ == '__main__':
    main()