parent
d26da4b68c
commit
c731d54e18
|
@ -63,14 +63,14 @@ def main():
|
|||
elif op_name in ['SRResize']:
|
||||
op[op_name]['infer_mode'] = True
|
||||
elif op_name == 'KeepKeys':
|
||||
op[op_name]['keep_keys'] = ['imge_lr']
|
||||
op[op_name]['keep_keys'] = ['img_lr']
|
||||
transforms.append(op)
|
||||
global_config['infer_mode'] = True
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
||||
save_res_path = config['Global'].get('save_res_path', "./infer_result")
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
save_visual_path = config['Global'].get('save_visual', "infer_result/")
|
||||
if not os.path.exists(os.path.dirname(save_visual_path)):
|
||||
os.makedirs(os.path.dirname(save_visual_path))
|
||||
|
||||
model.eval()
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
|
@ -87,7 +87,7 @@ def main():
|
|||
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
|
||||
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
|
||||
img_name_pure = os.path.split(file)[-1]
|
||||
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
|
||||
cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
|
||||
fm_sr[:, :, ::-1])
|
||||
logger.info("The visualized image saved in infer_result/sr_{}".format(
|
||||
img_name_pure))
|
||||
|
|
|
@ -231,7 +231,8 @@ def train(config,
|
|||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
extra_input_models = [
|
||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
|
||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
|
||||
"RobustScanner"
|
||||
]
|
||||
extra_input = False
|
||||
if config['Architecture']['algorithm'] == 'Distillation':
|
||||
|
@ -503,16 +504,6 @@ def eval(model,
|
|||
preds = model(batch)
|
||||
sr_img = preds["sr_img"]
|
||||
lr_img = preds["lr_img"]
|
||||
|
||||
for i in (range(sr_img.shape[0])):
|
||||
fm_sr = (sr_img[i].numpy() * 255).transpose(
|
||||
1, 2, 0).astype(np.uint8)
|
||||
fm_lr = (lr_img[i].numpy() * 255).transpose(
|
||||
1, 2, 0).astype(np.uint8)
|
||||
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
|
||||
sum_images, i), fm_sr)
|
||||
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
|
||||
sum_images, i), fm_lr)
|
||||
else:
|
||||
preds = model(images)
|
||||
preds = to_float32(preds)
|
||||
|
@ -525,16 +516,6 @@ def eval(model,
|
|||
preds = model(batch)
|
||||
sr_img = preds["sr_img"]
|
||||
lr_img = preds["lr_img"]
|
||||
|
||||
for i in (range(sr_img.shape[0])):
|
||||
fm_sr = (sr_img[i].numpy() * 255).transpose(
|
||||
1, 2, 0).astype(np.uint8)
|
||||
fm_lr = (lr_img[i].numpy() * 255).transpose(
|
||||
1, 2, 0).astype(np.uint8)
|
||||
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
|
||||
sum_images, i), fm_sr)
|
||||
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
|
||||
sum_images, i), fm_lr)
|
||||
else:
|
||||
preds = model(images)
|
||||
|
||||
|
|
Loading…
Reference in New Issue