[Fix] MMEditing cannot save results when testing (#336)
* fix show * lint * remove redundant codes * resolve comment * type hintpull/360/head
parent
6e7e219b0b
commit
d7adf815a0
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Union
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
|
@ -88,6 +89,8 @@ class End2EndModel(BaseBackendModel):
|
|||
def forward_test(self,
|
||||
lq: torch.Tensor,
|
||||
gt: Optional[torch.Tensor] = None,
|
||||
meta: List[Dict] = None,
|
||||
save_path=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""Run inference for restorer to generate evaluation result.
|
||||
|
@ -96,6 +99,8 @@ class End2EndModel(BaseBackendModel):
|
|||
lq (torch.Tensor): The input low-quality image of the model.
|
||||
gt (torch.Tensor): The ground truth of input image. Defaults to
|
||||
`None`.
|
||||
meta (List[Dict]): The meta infomations of MMEditing.
|
||||
save_path (str): Path to save image. Default: None.
|
||||
*args: Other arguments.
|
||||
**kwargs: Other key-pair arguments.
|
||||
|
||||
|
@ -104,6 +109,17 @@ class End2EndModel(BaseBackendModel):
|
|||
"""
|
||||
outputs = self.forward_dummy(lq)
|
||||
result = self.test_post_process(outputs, lq, gt)
|
||||
|
||||
# Align to mmediting BasicRestorer
|
||||
if save_path:
|
||||
outputs = [torch.from_numpy(i) for i in outputs]
|
||||
|
||||
lq_path = meta[0]['lq_path']
|
||||
folder_name = osp.splitext(osp.basename(lq_path))[0]
|
||||
save_path = osp.join(save_path, f'{folder_name}.png')
|
||||
|
||||
mmcv.imwrite(tensor2img(outputs), save_path)
|
||||
|
||||
return result
|
||||
|
||||
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue