mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] MMEditing cannot save results when testing (#336)
* fix show * lint * remove redundant codes * resolve comment * type hint
This commit is contained in:
parent
6e7e219b0b
commit
d7adf815a0
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import List, Optional, Sequence, Union
|
import os.path as osp
|
||||||
|
from typing import Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -88,6 +89,8 @@ class End2EndModel(BaseBackendModel):
|
|||||||
def forward_test(self,
|
def forward_test(self,
|
||||||
lq: torch.Tensor,
|
lq: torch.Tensor,
|
||||||
gt: Optional[torch.Tensor] = None,
|
gt: Optional[torch.Tensor] = None,
|
||||||
|
meta: List[Dict] = None,
|
||||||
|
save_path=None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Run inference for restorer to generate evaluation result.
|
"""Run inference for restorer to generate evaluation result.
|
||||||
@ -96,6 +99,8 @@ class End2EndModel(BaseBackendModel):
|
|||||||
lq (torch.Tensor): The input low-quality image of the model.
|
lq (torch.Tensor): The input low-quality image of the model.
|
||||||
gt (torch.Tensor): The ground truth of input image. Defaults to
|
gt (torch.Tensor): The ground truth of input image. Defaults to
|
||||||
`None`.
|
`None`.
|
||||||
|
meta (List[Dict]): The meta infomations of MMEditing.
|
||||||
|
save_path (str): Path to save image. Default: None.
|
||||||
*args: Other arguments.
|
*args: Other arguments.
|
||||||
**kwargs: Other key-pair arguments.
|
**kwargs: Other key-pair arguments.
|
||||||
|
|
||||||
@ -104,6 +109,17 @@ class End2EndModel(BaseBackendModel):
|
|||||||
"""
|
"""
|
||||||
outputs = self.forward_dummy(lq)
|
outputs = self.forward_dummy(lq)
|
||||||
result = self.test_post_process(outputs, lq, gt)
|
result = self.test_post_process(outputs, lq, gt)
|
||||||
|
|
||||||
|
# Align to mmediting BasicRestorer
|
||||||
|
if save_path:
|
||||||
|
outputs = [torch.from_numpy(i) for i in outputs]
|
||||||
|
|
||||||
|
lq_path = meta[0]['lq_path']
|
||||||
|
folder_name = osp.splitext(osp.basename(lq_path))[0]
|
||||||
|
save_path = osp.join(save_path, f'{folder_name}.png')
|
||||||
|
|
||||||
|
mmcv.imwrite(tensor2img(outputs), save_path)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
|
def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user