[Fix] MMEditing cannot save results when testing (#336)

* fix show

* lint

* remove redundant codes

* resolve comment

* type hint
pull/360/head
Yifan Zhou 2022-04-14 20:25:31 +08:00 committed by GitHub
parent 6e7e219b0b
commit d7adf815a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 1 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
import mmcv
import numpy as np
@ -88,6 +89,8 @@ class End2EndModel(BaseBackendModel):
def forward_test(self,
lq: torch.Tensor,
gt: Optional[torch.Tensor] = None,
meta: List[Dict] = None,
save_path=None,
*args,
**kwargs):
"""Run inference for restorer to generate evaluation result.
@ -96,6 +99,8 @@ class End2EndModel(BaseBackendModel):
lq (torch.Tensor): The input low-quality image of the model.
gt (torch.Tensor): The ground truth of input image. Defaults to
`None`.
meta (List[Dict]): The meta infomations of MMEditing.
save_path (str): Path to save image. Default: None.
*args: Other arguments.
**kwargs: Other key-pair arguments.
@ -104,6 +109,17 @@ class End2EndModel(BaseBackendModel):
"""
outputs = self.forward_dummy(lq)
result = self.test_post_process(outputs, lq, gt)
# Align to mmediting BasicRestorer
if save_path:
outputs = [torch.from_numpy(i) for i in outputs]
lq_path = meta[0]['lq_path']
folder_name = osp.splitext(osp.basename(lq_path))[0]
save_path = osp.join(save_path, f'{folder_name}.png')
mmcv.imwrite(tensor2img(outputs), save_path)
return result
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):