inference for all the images in one folder is added.
parent
2b4af71ebe
commit
33a296c7fa
|
@ -0,0 +1,54 @@
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Copyright (c) 2022 megvii-model. All Rights Reserved.
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||||
|
# Copyright 2018-2020 BasicSR Authors
|
||||||
|
# ------------------------------------------------------------------------
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# from basicsr.data import create_dataloader, create_dataset
|
||||||
|
from basicsr.models import create_model
|
||||||
|
from basicsr.train import parse_options
|
||||||
|
from basicsr.utils import FileClient, imfrombytes, img2tensor, padding, tensor2img, imwrite
|
||||||
|
import os
|
||||||
|
|
||||||
|
# from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
|
||||||
|
# make_exp_dirs)
|
||||||
|
# from basicsr.utils.options import dict2str
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse options, set distributed setting, set ramdom seed
|
||||||
|
opt = parse_options(is_train=False)
|
||||||
|
opt['num_gpu'] = torch.cuda.device_count()
|
||||||
|
input_folder = opt['img_path'].get('input_folder')
|
||||||
|
output_folder = opt['img_path'].get('output_folder')
|
||||||
|
# Get a list of all image files in the input folder
|
||||||
|
image_files = [f for f in os.listdir(input_folder) if os.path.isfile(os.path.join(input_folder, f))]
|
||||||
|
opt['dist'] = False
|
||||||
|
model = create_model(opt)
|
||||||
|
for image_file in image_files:
|
||||||
|
# Construct the input and output paths for each image
|
||||||
|
img_path = os.path.join(input_folder, image_file)
|
||||||
|
output_path = os.path.join(output_folder, image_file)
|
||||||
|
## 1. read image
|
||||||
|
file_client = FileClient('disk')
|
||||||
|
img_bytes = file_client.get(img_path, None)
|
||||||
|
try:
|
||||||
|
img = imfrombytes(img_bytes, float32=True)
|
||||||
|
except:
|
||||||
|
raise Exception("path {} not working".format(img_path))
|
||||||
|
img = img2tensor(img, bgr2rgb=True, float32=True)
|
||||||
|
## 2. run inference
|
||||||
|
|
||||||
|
model.feed_data(data={'lq': img.unsqueeze(dim=0)})
|
||||||
|
if model.opt['val'].get('grids', False):
|
||||||
|
model.grids()
|
||||||
|
model.test()
|
||||||
|
if model.opt['val'].get('grids', False):
|
||||||
|
model.grids_inverse()
|
||||||
|
visuals = model.get_current_visuals()
|
||||||
|
sr_img = tensor2img([visuals['result']])
|
||||||
|
imwrite(sr_img, output_path)
|
||||||
|
print(f'inference {img_path} .. finished. saved to {output_path}')
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -35,9 +35,10 @@ def parse_options(is_train=True):
|
||||||
default='none',
|
default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
||||||
parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.')
|
parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.')
|
||||||
|
parser.add_argument('--input_folder', type=str, required=False, help='The path to the input folder. For multiple image inference.')
|
||||||
parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.')
|
parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.')
|
||||||
|
parser.add_argument('--output_folder', type=str, required=False, help='The path to the output folder. For multiple image inference.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = parse(args.opt, is_train=is_train)
|
opt = parse(args.opt, is_train=is_train)
|
||||||
|
@ -68,6 +69,11 @@ def parse_options(is_train=True):
|
||||||
'input_img': args.input_path,
|
'input_img': args.input_path,
|
||||||
'output_img': args.output_path
|
'output_img': args.output_path
|
||||||
}
|
}
|
||||||
|
elif args.input_folder is not None and args.output_folder is not None:
|
||||||
|
opt['img_path'] = {
|
||||||
|
'input_folder': args.input_folder,
|
||||||
|
'output_folder': args.output_folder
|
||||||
|
}
|
||||||
|
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# GENERATED VERSION FILE
|
# GENERATED VERSION FILE
|
||||||
# TIME: Mon Apr 18 21:35:20 2022
|
# TIME: Mon Jun 17 23:20:12 2024
|
||||||
__version__ = '1.2.0+386ca20'
|
__version__ = '1.2.0+2b4af71'
|
||||||
short_version = '1.2.0'
|
short_version = '1.2.0'
|
||||||
version_info = (1, 2, 0)
|
version_info = (1, 2, 0)
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 61 KiB After Width: | Height: | Size: 61 KiB |
Loading…
Reference in New Issue