# ------------------------------------------------------------------------ # Copyright (c) 2022 megvii-model. All Rights Reserved. # ------------------------------------------------------------------------ # Modified from BasicSR (https://github.com/xinntao/BasicSR) # Copyright 2018-2020 BasicSR Authors # ------------------------------------------------------------------------ import torch # from basicsr.data import create_dataloader, create_dataset from basicsr.models import create_model from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite, set_random_seed import argparse from basicsr.utils.options import dict2str, parse from basicsr.utils.dist_util import get_dist_info, init_dist import random def parse_options(is_train=True): parser = argparse.ArgumentParser() parser.add_argument( '-opt', type=str, required=True, help='Path to option YAML file.') parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--input_l_path', type=str, required=True, help='The path to the input left image. For stereo image inference only.') parser.add_argument('--input_r_path', type=str, required=True, help='The path to the input right image. For stereo image inference only.') parser.add_argument('--output_l_path', type=str, required=True, help='The path to the output left image. For stereo image inference only.') parser.add_argument('--output_r_path', type=str, required=True, help='The path to the output right image. For stereo image inference only.') args = parser.parse_args() opt = parse(args.opt, is_train=is_train) # distributed settings if args.launcher == 'none': opt['dist'] = False print('Disable distributed.', flush=True) else: opt['dist'] = True if args.launcher == 'slurm' and 'dist_params' in opt: init_dist(args.launcher, **opt['dist_params']) else: init_dist(args.launcher) print('init dist .. ', args.launcher) opt['rank'], opt['world_size'] = get_dist_info() # random seed seed = opt.get('manual_seed') if seed is None: seed = random.randint(1, 10000) opt['manual_seed'] = seed set_random_seed(seed + opt['rank']) opt['img_path'] = { 'input_l': args.input_l_path, 'input_r': args.input_r_path, 'output_l': args.output_l_path, 'output_r': args.output_r_path } return opt def imread(img_path): file_client = FileClient('disk') img_bytes = file_client.get(img_path, None) try: img = imfrombytes(img_bytes, float32=True) except: raise Exception("path {} not working".format(img_path)) img = img2tensor(img, bgr2rgb=True, float32=True) return img def main(): # parse options, set distributed setting, set ramdom seed opt = parse_options(is_train=False) img_l_path = opt['img_path'].get('input_l') img_r_path = opt['img_path'].get('input_r') output_l_path = opt['img_path'].get('output_l') output_r_path = opt['img_path'].get('output_r') ## 1. read image img_l = imread(img_l_path) img_r = imread(img_r_path) img = torch.cat([img_l, img_r], dim=0) ## 2. run inference opt['dist'] = False model = create_model(opt) model.feed_data(data={'lq': img.unsqueeze(dim=0)}) if model.opt['val'].get('grids', False): model.grids() model.test() if model.opt['val'].get('grids', False): model.grids_inverse() visuals = model.get_current_visuals() sr_img_l = visuals['result'][:,:3] sr_img_r = visuals['result'][:,3:] sr_img_l, sr_img_r = tensor2img([sr_img_l, sr_img_r]) imwrite(sr_img_l, output_l_path) imwrite(sr_img_r, output_r_path) print(f'inference {img_l_path} .. finished. saved to {output_l_path}') print(f'inference {img_r_path} .. finished. saved to {output_r_path}') if __name__ == '__main__': main()