[Enhancement] refine mmagic SDKEnd2EndModel forward (#2092)

* update SDKEnd2EndModel of super_resolution_model.py

* recover

* docstring
pull/2094/head
AllentDan 2023-05-19 17:11:51 +08:00 committed by GitHub
parent 5e9d27b8d6
commit 98f895c21d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 21 deletions

View File

@ -144,6 +144,7 @@ class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmagic format."""
def __init__(self, *args, **kwargs):
kwargs.update(dict(data_preprocessor=None))
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
def forward(self,
@ -159,7 +160,8 @@ class SDKEnd2EndModel(End2EndModel):
it is an image.
Args:
inputs (torch.Tensor): The input tensors
inputs (torch.Tensor): A list contains input image(s)
in [C x H x W] format.
data_samples (List[BaseDataElement], optional): The data samples.
Defaults to None.
mode (str, optional): forward mode, only support `predict`.
@ -169,28 +171,18 @@ class SDKEnd2EndModel(End2EndModel):
Returns:
list | dict: High resolution image or a evaluation results.
"""
if hasattr(self.data_preprocessor, 'destructor'):
inputs = self.data_preprocessor.destructor(
inputs.to(self.data_preprocessor.std.device))
outputs = []
for i in range(inputs.shape[0]):
output = self.wrapper.invoke(inputs[i].permute(
1, 2, 0).contiguous().detach().cpu().numpy() * 255.)
for input in inputs:
output = self.wrapper.invoke(
input.permute(1, 2, 0).contiguous().detach().cpu().numpy())
outputs.append(
torch.from_numpy(output).permute(2, 0, 1).contiguous())
outputs = torch.stack(outputs, 0) / 255.
assert hasattr(self.data_preprocessor, 'destruct')
outputs = self.data_preprocessor.destruct(
outputs.to(self.data_preprocessor.std.device), data_samples)
outputs = torch.stack(outputs, 0)
outputs = DataSample(pred_img=outputs.cpu()).split()
# create a stacked data sample here
predictions = DataSample(pred_img=outputs.cpu())
predictions = self.convert_to_datasample_list(predictions,
data_samples, inputs)
return predictions
for data_sample, pred in zip(data_samples, outputs):
data_sample.output = pred
return data_samples
def build_super_resolution_model(

View File

@ -685,8 +685,8 @@ class SDKEnd2EndModel(End2EndModel):
"""Run forward inference.
Args:
img (Sequence[Tensor]): A list contains input image(s)
in [N x C x H x W] format.
inputs (Sequence[Tensor]): A list contains input image(s)
in [C x H x W] format.
data_samples (List[BaseDataElement]): A list of meta info
for image(s).
*args: Other arguments.