[Enhancement] refine mmagic SDKEnd2EndModel forward (#2092)
* update SDKEnd2EndModel of super_resolution_model.py * recover * docstringpull/2094/head
parent
5e9d27b8d6
commit
98f895c21d
|
@ -144,6 +144,7 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
"""SDK inference class, converts SDK output to mmagic format."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(dict(data_preprocessor=None))
|
||||
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,
|
||||
|
@ -159,7 +160,8 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
it is an image.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensors
|
||||
inputs (torch.Tensor): A list contains input image(s)
|
||||
in [C x H x W] format.
|
||||
data_samples (List[BaseDataElement], optional): The data samples.
|
||||
Defaults to None.
|
||||
mode (str, optional): forward mode, only support `predict`.
|
||||
|
@ -169,28 +171,18 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
Returns:
|
||||
list | dict: High resolution image or a evaluation results.
|
||||
"""
|
||||
|
||||
if hasattr(self.data_preprocessor, 'destructor'):
|
||||
inputs = self.data_preprocessor.destructor(
|
||||
inputs.to(self.data_preprocessor.std.device))
|
||||
|
||||
outputs = []
|
||||
for i in range(inputs.shape[0]):
|
||||
output = self.wrapper.invoke(inputs[i].permute(
|
||||
1, 2, 0).contiguous().detach().cpu().numpy() * 255.)
|
||||
for input in inputs:
|
||||
output = self.wrapper.invoke(
|
||||
input.permute(1, 2, 0).contiguous().detach().cpu().numpy())
|
||||
outputs.append(
|
||||
torch.from_numpy(output).permute(2, 0, 1).contiguous())
|
||||
outputs = torch.stack(outputs, 0) / 255.
|
||||
assert hasattr(self.data_preprocessor, 'destruct')
|
||||
outputs = self.data_preprocessor.destruct(
|
||||
outputs.to(self.data_preprocessor.std.device), data_samples)
|
||||
outputs = torch.stack(outputs, 0)
|
||||
outputs = DataSample(pred_img=outputs.cpu()).split()
|
||||
|
||||
# create a stacked data sample here
|
||||
predictions = DataSample(pred_img=outputs.cpu())
|
||||
|
||||
predictions = self.convert_to_datasample_list(predictions,
|
||||
data_samples, inputs)
|
||||
return predictions
|
||||
for data_sample, pred in zip(data_samples, outputs):
|
||||
data_sample.output = pred
|
||||
return data_samples
|
||||
|
||||
|
||||
def build_super_resolution_model(
|
||||
|
|
|
@ -685,8 +685,8 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
"""Run forward inference.
|
||||
|
||||
Args:
|
||||
img (Sequence[Tensor]): A list contains input image(s)
|
||||
in [N x C x H x W] format.
|
||||
inputs (Sequence[Tensor]): A list contains input image(s)
|
||||
in [C x H x W] format.
|
||||
data_samples (List[BaseDataElement]): A list of meta info
|
||||
for image(s).
|
||||
*args: Other arguments.
|
||||
|
|
Loading…
Reference in New Issue