diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py b/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py index ade5d0bee..454de8b95 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution_model.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Sequence, Union +import os.path as osp +from typing import Dict, List, Optional, Sequence, Union import mmcv import numpy as np @@ -88,6 +89,8 @@ class End2EndModel(BaseBackendModel): def forward_test(self, lq: torch.Tensor, gt: Optional[torch.Tensor] = None, + meta: List[Dict] = None, + save_path=None, *args, **kwargs): """Run inference for restorer to generate evaluation result. @@ -96,6 +99,8 @@ class End2EndModel(BaseBackendModel): lq (torch.Tensor): The input low-quality image of the model. gt (torch.Tensor): The ground truth of input image. Defaults to `None`. + meta (List[Dict]): The meta infomations of MMEditing. + save_path (str): Path to save image. Default: None. *args: Other arguments. **kwargs: Other key-pair arguments. @@ -104,6 +109,17 @@ class End2EndModel(BaseBackendModel): """ outputs = self.forward_dummy(lq) result = self.test_post_process(outputs, lq, gt) + + # Align to mmediting BasicRestorer + if save_path: + outputs = [torch.from_numpy(i) for i in outputs] + + lq_path = meta[0]['lq_path'] + folder_name = osp.splitext(osp.basename(lq_path))[0] + save_path = osp.join(save_path, f'{folder_name}.png') + + mmcv.imwrite(tensor2img(outputs), save_path) + return result def forward_dummy(self, lq: torch.Tensor, *args, **kwargs):